<a href="https://colab.research.google.com/github/JSK2022/RandomForest-for-Recurrent-Events/blob/Thesis-Code/JS_simulation_code_230923.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install scikit-survival

In [None]:
import numpy as np
import pandas as pd


# Machine-specific directory that holds the simulated recurrent-event data.
path="/Users/jeongsookim/Downloads"
data = pd.read_csv(f"{path}/simuDat.csv")

# Pull the columns used below out of the frame as plain NumPy arrays.
ids = data['id'].values
time_start = data['start'].values
time_stop = data['stop'].values
event = data['event'].values
# Covariate matrix; column order is (group, x1, gender).
x = data[['group','x1','gender']].values

### RisksetCounter

In [None]:
class RisksetCounter:
    """At-risk and event counts per time point for recurrent-event records.

    The data are rows ``(id, start, stop, event)``.  For every time ``t`` in
    ``all_unique_times`` the counter stores

    * ``n_at_risk[k]`` - the number of distinct ids that have at least one
      row with ``stop >= t``;
    * ``n_events[k]`` - the number of distinct ids that have a row with
      ``event == 1`` and ``stop == t``.

    ``update`` narrows the data to a subset of ids after pushing the current
    state on ``state_stack``; ``reset`` / ``load_state`` pop that state again.
    """

    def __init__(self, ids, time_start, time_stop, event):
        """Store the records and count on the grid of distinct stop times.

        Parameters
        ----------
        ids, time_start, time_stop, event : array-like of equal length
            One entry per record.  They are converted with ``np.asarray`` so
            that boolean-mask indexing in ``update`` also works for lists.
        """
        self.ids = np.asarray(ids)
        self.time_start = np.asarray(time_start)
        self.time_stop = np.asarray(time_stop)
        self.event = np.asarray(event)

        self.state_stack = []

        # A freshly built counter uses the distinct stop times as its grid.
        self.all_unique_times = np.unique(self.time_stop)
        self.set_data()

    def set_data(self):
        """Recompute ``n_at_risk`` and ``n_events`` on ``all_unique_times``.

        The count arrays (and ``n_unique_times``) are re-created with the
        length of the current grid, so the method can be called after the
        data or the grid changed.
        """
        times = self.all_unique_times
        self.n_unique_times = len(times)
        self.n_at_risk = np.zeros(self.n_unique_times, dtype=np.int64)
        self.n_events = np.zeros(self.n_unique_times, dtype=np.int64)

        has_event = self.event == 1
        for t_idx, t in enumerate(times):
            # Distinct ids with any record still running at t; this equals
            # summing Y_i over all ids.
            self.n_at_risk[t_idx] = len(np.unique(self.ids[self.time_stop >= t]))
            # Distinct ids with an event recorded exactly at t; this equals
            # summing dN_bar_i over all ids.
            self.n_events[t_idx] = len(
                np.unique(self.ids[has_event & (self.time_stop == t)])
            )

    def Y_i(self, id_, t_idx):
        """At-risk indicator of ``id_`` at grid position ``t_idx``.

        Returns 0 when ``t_idx`` is past the end of the grid, otherwise
        whether ``id_`` has any record with ``stop >=`` the grid time.
        """
        if t_idx >= len(self.all_unique_times):
            return 0
        time_at_t_idx = self.all_unique_times[t_idx]
        indices = (self.ids == id_) & (time_at_t_idx <= self.time_stop)
        return np.any(indices)

    def dN_bar_i(self, id_, t_idx):
        """Event indicator of ``id_`` at grid position ``t_idx``.

        Returns 0 when ``t_idx`` is past the end of the grid, otherwise
        whether ``id_`` has a record with ``event == 1`` whose ``stop``
        equals the grid time.
        """
        if t_idx >= len(self.all_unique_times):
            return 0
        time_at_t_idx = self.all_unique_times[t_idx]
        indices = (self.ids == id_) & (time_at_t_idx == self.time_stop) & (self.event == 1)
        return np.any(indices)

    def save_state(self):
        """Push a copy of the data, the time grid and the counts on the stack."""
        self.state_stack.append((
            self.ids.copy(),
            self.time_start.copy(),
            self.time_stop.copy(),
            self.event.copy(),
            self.n_at_risk.copy(),
            self.n_events.copy(),
            # The grid belongs to the state as well: the counts and the
            # indices taken by Y_i / dN_bar_i are only meaningful relative
            # to it.
            np.array(self.all_unique_times, copy=True),
        ))

    def load_state(self):
        """Pop the most recently saved state, if any, and make it current."""
        if self.state_stack:
            (self.ids, self.time_start, self.time_stop, self.event,
             self.n_at_risk, self.n_events,
             self.all_unique_times) = self.state_stack.pop()
            self.n_unique_times = len(self.all_unique_times)

    def update(self, new_ids, new_time_start, new_time_stop, new_event):
        """Restrict the counter to the records whose id occurs in ``new_ids``.

        The current state is saved first so that ``reset`` can undo the
        restriction.  Only ``new_ids`` is used; ``new_time_start``,
        ``new_time_stop`` and ``new_event`` are accepted for interface
        compatibility and ignored.  After the restriction the grid becomes
        the distinct start *and* stop times of the remaining records and the
        counts are recomputed on it.
        """
        self.save_state()

        keep = np.isin(self.ids, new_ids)
        self.ids = self.ids[keep]
        self.time_start = self.time_start[keep]
        self.time_stop = self.time_stop[keep]
        self.event = self.event[keep]

        self.all_unique_times = np.unique(np.concatenate([self.time_start, self.time_stop]))
        self.set_data()

    def reset(self):
        """Undo the most recent ``update`` (no-op when nothing was saved)."""
        self.load_state()

    def copy(self):
        """Return a new counter built from copies of the current records.

        The copy is constructed from scratch, so it counts on the distinct
        stop times and starts with an empty state stack.
        """
        return RisksetCounter(self.ids.copy(), self.time_start.copy(), self.time_stop.copy(), self.event.copy())

    def __reduce__(self):
        """Pickle as a constructor call on the current records."""
        return (self.__class__, (self.ids, self.time_start, self.time_stop, self.event))

# Let's check if the updated RisksetCounter works:
# build a counter on the full data set and display the per-time counts
# together with their lengths.
riskset_test = RisksetCounter(ids, time_start, time_stop, event)
riskset_test.n_at_risk, riskset_test.n_events, len(riskset_test.n_at_risk), len(riskset_test.n_events)


In [None]:
import numpy as np
import pandas as pd


path="/Users/jeongsookim/Downloads"
simuDat= pd.read_csv(f"{path}/simuDat.csv")
simuDat

In [None]:
# Counter over the complete data set; the cells below narrow it with
# update() and restore it with reset().
riskset_counter = RisksetCounter(ids=simuDat["id"].values,
                                       time_start=simuDat["start"].values,
                                       time_stop=simuDat["stop"].values,
                                       event=simuDat["event"].values)

#### Group==Contr

In [None]:
# NOTE: despite its name, this frame holds the rows with group == 0
# (the "Contr" subset of the heading above), not a gender subset.
simuDat_male = simuDat[simuDat["group"] == 0]

In [None]:
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [None]:
riskset_counter.n_at_risk, len(riskset_counter.n_at_risk)

In [None]:
riskset_counter.n_events, len(riskset_counter.n_events)

In [None]:
np.unique(riskset_counter.ids)

In [None]:
riskset_counter.reset()

In [None]:
simuDat_male = simuDat[simuDat["group"] == 1]

In [None]:
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [None]:
riskset_counter.n_at_risk, len(riskset_counter.n_at_risk)

In [None]:
riskset_counter.n_events, len(riskset_counter.n_events)

In [None]:
np.unique(riskset_counter.ids)

#### Male

In [None]:
riskset_counter.reset()

In [None]:
simuDat_male = simuDat[simuDat["gender"] == 0]

In [None]:
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [None]:
riskset_counter.n_at_risk, len(riskset_counter.n_at_risk)

In [None]:
riskset_counter.n_events, len(riskset_counter.n_events)

In [None]:
np.unique(riskset_counter.ids)

In [None]:
np.unique(simuDat_male['id'].values)

#### Female

In [None]:
riskset_counter.reset()

In [None]:
simuDat_male = simuDat[simuDat["gender"] == 1]

In [None]:
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [None]:
riskset_counter.n_events, len(riskset_counter.n_events)


In [None]:
riskset_counter.n_at_risk, len(riskset_counter.n_at_risk)

In [None]:
riskset_counter.n_events, len(riskset_counter.n_events)


In [None]:
np.unique(riskset_counter.ids)

In [None]:
np.unique(simuDat_male['id'].values)

#### Contr&male

In [None]:
simuDat_sub1 = simuDat[(simuDat["group"] == 0)&(simuDat['gender']==0)]

In [None]:
riskset_counter.reset()

In [None]:
np.unique(riskset_counter.ids)

In [None]:
simuDat_male = simuDat[simuDat["group"] == 0]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [None]:
simuDat_male = simuDat[simuDat['gender'] == 0]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [None]:
np.unique(riskset_counter.ids)

In [None]:
np.unique(simuDat_sub1['id'].values)

#### Contr & female

In [None]:
simuDat_sub1 = simuDat[(simuDat["group"] == 0)&(simuDat['gender']==1)]

In [None]:
riskset_counter.reset()

In [None]:
np.unique(riskset_counter.ids)

In [None]:
simuDat_male = simuDat[simuDat["group"] == 0]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [None]:
simuDat_male = simuDat[simuDat["gender"] == 1]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [None]:
np.unique(riskset_counter.ids)

In [None]:
np.unique(simuDat_sub1['id'].values)

#### Treat&male

In [None]:
simuDat_sub1 = simuDat[(simuDat["group"] == 1)&(simuDat['gender']==0)]

In [None]:
riskset_counter.reset()

In [None]:
simuDat_male = simuDat[simuDat['group']==1]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [None]:
simuDat_male = simuDat[simuDat['gender']==0]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [None]:
np.unique(riskset_counter.ids)

In [None]:
np.unique(simuDat_sub1['id'].values)

#### Treat&female

In [None]:
simuDat_sub1 = simuDat[(simuDat["group"] == 1)&(simuDat['gender']==1)]

In [None]:
riskset_counter.reset()

In [None]:
simuDat_male = simuDat[simuDat['group']==1]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [None]:
simuDat_male = simuDat[simuDat['gender']==1]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [None]:
np.unique(riskset_counter.ids)

In [None]:
np.unique(simuDat_sub1['id'].values)

In [None]:
def argbinsearch(arr, key_val):
    """Return the index of the first element of ``arr`` greater than ``key_val``.

    ``arr`` must be sorted in ascending order.  Elements equal to ``key_val``
    are skipped (the comparison is ``<=``), so the result is the right-hand
    insertion point, the same value ``bisect.bisect_right`` gives.  If no
    element exceeds ``key_val`` the length of ``arr`` is returned; an empty
    ``arr`` yields 0.
    """
    lo = 0
    hi = len(arr)

    # Invariant: every element before ``lo`` is <= key_val and every element
    # from ``hi`` on is > key_val.  ``mid`` always lies in [lo, hi), so no
    # separate bounds check is needed.
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if arr[mid] <= key_val:
            lo = mid + 1
        else:
            hi = mid

    return lo

이 함수는 argbinsearch라는 이름의 함수로, 정렬된 배열에서 주어진 키 값보다 큰(초과하는) 첫 번째 원소의 인덱스를 이진 탐색으로 찾아 반환합니다. 키 값보다 큰 원소가 없으면 배열의 길이를 반환합니다. (비교 조건이 `<=`이므로 키 값과 같은 원소는 건너뛰며, `bisect_right`와 같은 동작입니다.)

자세한 코드 설명을 아래에 제공합니다:

1. 입력:

  * arr: 탐색 대상인 정렬된 배열
  * key_val: 찾고자 하는 키 값

2. 초기 변수 설정:

  * arr_len: 배열의 길이를 저장합니다.
  * min_idx: 탐색 범위의 최솟값으로, 처음에는 배열의 시작 인덱스인 0으로 설정됩니다.
  * max_idx: 탐색 범위의 최댓값으로, 처음에는 배열의 길이로 설정됩니다.

3. 이진 탐색:

  * while 루프를 사용하여 min_idx가 max_idx보다 작은 동안 탐색을 반복합니다.
  * mid_idx: 현재 탐색 범위의 중간 인덱스를 계산합니다.
  * mid_val: 중간 인덱스에 해당하는 배열의 원소 값을 가져옵니다.

4. 키 값과 중간 값을 비교합니다:
  * 만약 중간 값이 키 값보다 작거나 같으면, min_idx를 mid_idx + 1로 업데이트합니다. 이렇게 하면 탐색 범위의 왼쪽 부분을 제외하게 됩니다.
  * 그렇지 않으면, max_idx를 mid_idx로 업데이트합니다. 이렇게 하면 탐색 범위의 오른쪽 부분을 제외하게 됩니다.

5. 결과 반환:

  * 루프가 종료되면, min_idx는 키 값보다 큰 첫 번째 원소의 인덱스(그런 원소가 없으면 배열의 길이)를 가리키게 됩니다. 따라서 min_idx를 반환합니다.

이 함수는 정렬된 배열에서 주어진 키 값보다 큰 첫 번째 원소의 위치를 효율적으로 찾기 위해 사용됩니다. 이진 탐색은 배열의 중간 값을 반복적으로 확인하면서 탐색 범위를 절반씩 줄여나가므로, 큰 배열에서도 빠르게 원하는 값을 찾을 수 있습니다.

### PseudoScoreCriterion

In [None]:
class PseudoScoreCriterion:
    """Pseudo-score split criterion for recurrent-event data.

    The object keeps the covariates and the ``(id, start, stop, event)``
    records of a node together with ``RisksetCounter`` objects for the whole
    node (``riskset_total``) and for the two candidate children
    (``riskset_left`` / ``riskset_right``).  The method that fills the child
    counters for a given split is attached to the class as ``update`` further
    down in this file.
    """

    def __init__(self, n_outputs, n_samples, unique_times, x, ids, time_start, time_stop, event):
        """Store the node data and create the risk-set counters.

        Parameters
        ----------
        n_outputs, n_samples
            Stored as given; not used by the computations in this class.
        unique_times : array-like
            Sorted time grid on which the statistics are evaluated.
        x : ndarray
            Covariate matrix, one row per record.
        ids, time_start, time_stop, event : ndarray
            The records, one entry per row of ``x``.
        """
        self.n_outputs = n_outputs
        self.n_samples = n_samples
        self.unique_times = unique_times
        self.x = x
        self.ids = ids
        self.time_start = time_start
        self.time_stop = time_stop
        self.event = event

        self.unique_ids = set(self.ids)  # Store unique ids for later use

        self.riskset_left = RisksetCounter(ids, time_start, time_stop, event)
        self.riskset_right = RisksetCounter(ids, time_start, time_stop, event)
        self.riskset_total = RisksetCounter(ids, time_start, time_stop, event)

        # Position of every record's stop time on the grid.
        self.samples_time_idx = np.searchsorted(unique_times, time_stop)

        self.split_pos = 0
        self.split_time_idx = 0

        # Counter used by update_riskset / save_riskset_state /
        # reset_riskset_state (newly added).
        self._riskset_counter = RisksetCounter(ids, time_start, time_stop, event)

    def init(self, y, sample_weight, n_samples, samples, start, end):
        """Point the criterion at the records ``samples[start:end]``.

        All three counters are reset, ``unique_times`` is extended with the
        stop times of the selected records and ``riskset_total`` is
        restricted to their ids.  ``y`` holds the columns
        ``(start, stop, event)``; ``sample_weight`` and ``n_samples`` are
        not used.
        """
        self.samples = samples
        self.riskset_left.reset()
        self.riskset_right.reset()
        self.riskset_total.reset()

        time_starts, stop_times, events = y[:, 0], y[:, 1], y[:, 2]
        ids_for_update = [self.ids[idx] for idx in samples[start:end]]
        time_starts_for_update = [time_starts[idx] for idx in samples[start:end]]
        stop_times_for_update = [stop_times[idx] for idx in samples[start:end]]
        events_for_update = [events[idx] for idx in samples[start:end]]

        # Combine unique times from both datasets
        self.unique_times = np.unique(np.concatenate([self.unique_times, stop_times_for_update]))

        self.riskset_total.update(ids_for_update, time_starts_for_update, stop_times_for_update, events_for_update)

    def set_unique_times(self, unique_times):
        """Sets the unique times for the current node."""
        self.unique_times = unique_times

    # Splitting by the group indicator only...

    # Functions returning the risk set value and event value for the given ID and time index from the respective risk set (left or right)
    def Y_left_value(self, id_, t):
        """At-risk indicator of ``id_`` at grid index ``t`` in the left child."""
        return self.riskset_left.Y_i(id_, t)

    def Y_right_value(self, id_, t):
        """At-risk indicator of ``id_`` at grid index ``t`` in the right child."""
        return self.riskset_right.Y_i(id_, t)

    def dN_bar_left_value(self, id_, t):
        """Event indicator of ``id_`` at grid index ``t`` in the left child."""
        return self.riskset_left.dN_bar_i(id_, t)

    def dN_bar_right_value(self, id_, t):
        """Event indicator of ``id_`` at grid index ``t`` in the right child."""
        return self.riskset_right.dN_bar_i(id_, t)

    def _aligned_counts(self, riskset):
        """Return ``(n_at_risk, n_events)`` of ``riskset`` padded to the grid.

        Counters with fewer time points than ``unique_times`` are extended at
        the end: the at-risk counts repeat their last value, the event
        counts are filled with zeros.

        Raises
        ------
        ValueError
            If the counter has more time points than ``unique_times``; the
            arrays could then not be combined element-wise.
        """
        n_times = len(self.unique_times)
        if max(len(riskset.n_at_risk), len(riskset.n_events)) > n_times:
            raise ValueError(
                "risk set has more time points than the criterion's unique_times"
            )
        n_at_risk = np.pad(riskset.n_at_risk, (0, n_times - len(riskset.n_at_risk)), 'edge')
        n_events = np.pad(riskset.n_events, (0, n_times - len(riskset.n_events)), 'constant')
        return n_at_risk, n_events

    def calculate_variance_estimate(self):
        """Compute the variance estimate of the split statistic.

        With ``w = Y_L * Y_R / (Y_L + Y_R)`` per time point, the value is the
        sum over ids and time points of
        ``w * (Y_i / Y_side) * (dN_i - dN_side / Y_side) ** 2``,
        added up for the left and the right child.
        """
        n_times = len(self.unique_times)
        left_at_risk, left_events = self._aligned_counts(self.riskset_left)
        right_at_risk, right_events = self._aligned_counts(self.riskset_right)

        w = (left_at_risk * right_at_risk) / (left_at_risk + right_at_risk)

        # One row per id, one column per time point.
        shape = (len(self.unique_ids), n_times)
        y_left = np.zeros(shape)
        y_right = np.zeros(shape)
        dn_left = np.zeros(shape)
        dn_right = np.zeros(shape)
        for row, id_ in enumerate(self.unique_ids):
            for t in range(n_times):
                y_left[row, t] = self.Y_left_value(id_, t)
                y_right[row, t] = self.Y_right_value(id_, t)
                dn_left[row, t] = self.dN_bar_left_value(id_, t)
                dn_right[row, t] = self.dN_bar_right_value(id_, t)

        # The per-time vectors broadcast along the id axis.
        term_left = (dn_left - left_events / left_at_risk) ** 2
        term_right = (dn_right - right_events / right_at_risk) ** 2

        var_estimate_L = np.sum(w * (y_left / left_at_risk) * term_left)
        var_estimate_R = np.sum(w * (y_right / right_at_risk) * term_right)

        return var_estimate_L + var_estimate_R

    def proxy_impurity_improvement(self):
        """Return the standardised pseudo-score statistic of the current split.

        The numerator is the ``w``-weighted sum of the difference between
        the left and right event rates; it is divided by the square root of
        ``calculate_variance_estimate()`` (plus ``1e-7``).  ``-np.inf`` is
        returned when either child counter has no time points.
        """
        # If either risk set is empty, return -np.inf
        if len(self.riskset_left.n_at_risk) == 0 or len(self.riskset_right.n_at_risk) == 0:
            return -np.inf

        left_at_risk, left_events = self._aligned_counts(self.riskset_left)
        right_at_risk, right_events = self._aligned_counts(self.riskset_right)

        w = (left_at_risk * right_at_risk) / (left_at_risk + right_at_risk)
        term = (left_events / left_at_risk) - (right_events / right_at_risk)
        numer = np.sum(w * term)
        var_estimate = self.calculate_variance_estimate()

        return numer / (np.sqrt(var_estimate) + 1e-7)

    def update_riskset(self, ids_subset):
        """Restrict the internal counter to the ids present at the current node."""
        unique_ids_subset = np.unique(ids_subset)
        # The counter is stored as ``_riskset_counter`` (see __init__).
        self._riskset_counter.update(unique_ids_subset, self.time_start, self.time_stop, self.event)

    def node_value(self):
        """
        Returns the Nelson-Aalen estimator of the mean function μ(t) for the entities in the current node.
        """
        return self.node_value_from_riskset(self.riskset_total)

    def node_value_from_riskset(self, riskset_counter):
        """
        Returns the Nelson-Aalen estimator of the mean function μ(t) for the entities based on provided riskset_counter.

        Only ``riskset_counter.n_at_risk`` and ``riskset_counter.n_events``
        are read.  The result is a list with one cumulative value per entry
        of ``unique_times``; positions beyond the counter's arrays add 0.
        """
        mu_hat_values = []

        # Initialize the cumulative sum of the Nelson-Aalen estimator
        cumsum_Nelson_Aalen = 0

        for t_idx, t in enumerate(self.unique_times):
            # Use n_at_risk and n_events from the riskset_counter
            n_at_risk_t = riskset_counter.n_at_risk[t_idx] if t_idx < len(riskset_counter.n_at_risk) else 0
            n_events_t = riskset_counter.n_events[t_idx] if t_idx < len(riskset_counter.n_events) else 0

            cumsum_Nelson_Aalen += n_events_t / (n_at_risk_t + 1e-7)  # Avoiding division by zero
            mu_hat_values.append(cumsum_Nelson_Aalen)

        return mu_hat_values

    # Methods for saving and restoring the state of the internal RisksetCounter.
    def save_riskset_state(self):
        """Push the state of the internal counter on its stack."""
        self._riskset_counter.save_state()

    def reset_riskset_state(self):
        """Restore the most recently saved state of the internal counter."""
        self._riskset_counter.reset()

    def reset(self):
        """
        Functions to reset all risk set counters
        """
        self.riskset_total.reset()
        self.riskset_left.reset()
        self.riskset_right.reset()

    def copy(self):
        """
        Creates and returns a copy of the current object.
        """
        new_criterion = PseudoScoreCriterion(self.n_outputs, self.n_samples, self.unique_times,
                                             self.x, self.ids, self.time_start, self.time_stop,
                                             self.event)
        new_criterion.riskset_left = self.riskset_left.copy()
        new_criterion.riskset_right = self.riskset_right.copy()
        new_criterion.riskset_total = self.riskset_total.copy()
        new_criterion.samples_time_idx = self.samples_time_idx.copy()
        if hasattr(self, 'samples'):
            new_criterion.samples = self.samples.copy()

        return new_criterion

def update_with_group_indicator(self, feature_index, group_indicator):
    """
    Update the criterion based on a specified feature and group indicator.

    Subjects are sent to the left child when ``x[:, feature_index] <=
    group_indicator`` holds for their first record and to the right child
    otherwise; all records of a subject follow that decision.  Both child
    counters are reset, given the criterion's ``unique_times`` and zeroed
    count arrays, and then restricted to their subjects through
    ``RisksetCounter.update``.

    NOTE(review): ``RisksetCounter.update`` as defined in this file rebuilds
    ``all_unique_times`` and the counts from the restricted records, so the
    grid and zero arrays assigned here end up in the state that ``update``
    saves rather than in the live counter - confirm that this is intended.
    """
    # Reset the riskset counters for the left and right nodes
    self.riskset_left.reset()
    self.riskset_right.reset()

    ids = np.asarray(self.ids)
    time_start = np.asarray(self.time_start)
    time_stop = np.asarray(self.time_stop)
    event = np.asarray(self.event)

    # Row-wise split rule (<= so that continuous features work as well).
    row_goes_left = np.asarray(self.x)[:, feature_index] <= group_indicator

    # One decision per subject, taken from the subject's first record.
    subjects, first_row = np.unique(ids, return_index=True)
    left_subjects = subjects[row_goes_left[first_row]]
    left_rows = np.isin(ids, left_subjects)

    n_times = len(self.unique_times)
    for riskset, rows in ((self.riskset_left, left_rows),
                          (self.riskset_right, ~left_rows)):
        # Put the child counter on the current node's time grid ...
        riskset.all_unique_times = self.unique_times
        riskset.n_at_risk = np.zeros(n_times, dtype=np.int64)
        riskset.n_events = np.zeros(n_times, dtype=np.int64)
        # ... and restrict it to the records of its subjects.
        riskset.update(ids[rows], time_start[rows], time_stop[rows], event[rows])

# Attach the function above to the PseudoScoreCriterion class as its `update` method.
setattr(PseudoScoreCriterion, 'update', update_with_group_indicator)

# In addition, define accessors that return the data of the left and right nodes.
def get_left_node_data(self):
    """Return ``(ids, n_at_risk, n_events)`` of the left child's risk set."""
    left = self.riskset_left
    return left.ids, left.n_at_risk, left.n_events

def get_right_node_data(self):
    """Return ``(ids, n_at_risk, n_events)`` of the right child's risk set."""
    right = self.riskset_right
    return right.ids, right.n_at_risk, right.n_events

# Attach both accessors to the PseudoScoreCriterion class.
setattr(PseudoScoreCriterion, 'get_left_node_data', get_left_node_data)
setattr(PseudoScoreCriterion, 'get_right_node_data', get_right_node_data)

def calculate_node_value_updated(self, side="left"):
    """
    Calculate the node value of one child from its current risk-set counts.

    The counts come from ``get_left_node_data`` / ``get_right_node_data``
    and are handed to ``node_value_from_riskset``, which reads nothing but
    the ``n_at_risk`` and ``n_events`` attributes.  A lightweight holder is
    therefore enough; building a full RisksetCounter would run its whole
    counting pass only to have the result overwritten.

    Parameters:
        - side (str): Either "left" or "right" to determine which riskset to use for calculation.

    Raises:
        ValueError: if ``side`` is neither "left" nor "right".
    """
    from types import SimpleNamespace

    if side == "left":
        _, n_at_risk, n_events = self.get_left_node_data()
    elif side == "right":
        _, n_at_risk, n_events = self.get_right_node_data()
    else:
        raise ValueError("Invalid side value. Expected 'left' or 'right'.")

    counts = SimpleNamespace(n_at_risk=n_at_risk, n_events=n_events)
    return self.node_value_from_riskset(counts)

# Attach the function defined above to the PseudoScoreCriterion class as `calculate_node_value`.
setattr(PseudoScoreCriterion, 'calculate_node_value', calculate_node_value_updated)



PseudoScoreCriterion


In [None]:
riskset_counter.reset()

In [None]:
n_samples = len(ids)

# Create an instance of the RisksetCounter and PseudoScoreCriterion classes
riskset = RisksetCounter(ids, time_start, time_stop, event)
criterion = PseudoScoreCriterion(n_outputs=1, n_samples=n_samples, unique_times=np.unique(time_stop), x=x, ids=ids, time_start=time_start, time_stop=time_stop, event=event)

criterion

In [None]:
feature_index = 0
group_indicator = 0

criterion.update(feature_index, group_indicator)

In [None]:
criterion.get_left_node_data()

In [None]:
criterion.get_right_node_data()

In [None]:
criterion.proxy_impurity_improvement()

In [None]:
criterion.calculate_node_value('left')

In [None]:
criterion.calculate_node_value('right')

In [None]:
riskset_counter.reset()

n_samples = len(ids)

# Create an instance of the RisksetCounter and PseudoScoreCriterion classes
riskset = RisksetCounter(ids, time_start, time_stop, event)
criterion = PseudoScoreCriterion(n_outputs=1, n_samples=n_samples, unique_times=np.unique(time_stop), x=x, ids=ids, time_start=time_start, time_stop=time_stop, event=event)

criterion

In [None]:
feature_index=2
group_indicator = 0

criterion.update(feature_index, group_indicator)

In [None]:
criterion.get_left_node_data()

In [None]:
criterion.get_right_node_data()

In [None]:
criterion.proxy_impurity_improvement()

In [None]:
criterion.node_value()

In [None]:
criterion.calculate_node_value('left')

In [None]:
criterion.calculate_node_value('right')

### 이전 버전에서 수정한 것!!!

In [None]:
import pandas as pd
import numpy as np
from sklearn.utils import check_random_state

class PseudoScoreTreeBuilder:
    """
    Class designed to build a decision tree based on the pseudo-score test statistics criterion,
    typically used in recurrent events data analysis.

    The tree is returned as nested dictionaries with the keys ``feature``,
    ``threshold``, ``left_child``, ``right_child``, ``node_value``,
    ``unique_times`` and ``ids``; leaves have ``feature`` set to ``None``.
    """

    TREE_UNDEFINED = -1  # Placeholder

    def __init__(self, max_depth=None, min_ids_split=2, min_ids_leaf=1,
                 max_features=None, max_thresholds=None, min_impurity_decrease=0,
                 random_state=None):
        """Store the hyper-parameters.

        max_depth : depth at which nodes are no longer split (None = no limit).
        min_ids_split : minimum number of distinct ids a node needs to be split.
        min_ids_leaf : minimum number of distinct ids each child must keep.
        max_features : stored only; not used by the builder.
        max_thresholds : if set, at most this many candidate thresholds per
            feature are tried (drawn at random).
        min_impurity_decrease : a node becomes a leaf when its best
            improvement is below this value.
        random_state : seed or generator passed to ``check_random_state``.
        """
        self.max_depth = max_depth
        self.min_ids_split = min_ids_split
        self.min_ids_leaf = min_ids_leaf
        self.max_features = max_features
        self.max_thresholds = max_thresholds
        self.min_impurity_decrease = min_impurity_decrease
        self.random_state = check_random_state(random_state)

    def split_indices(self, X_column, threshold, criterion, start, end):
        """Efficiently splits the data based on the given threshold for a specific feature column.

        ``X_column`` holds the feature values of rows ``start..end-1``; the
        returned ``(left, right)`` index arrays are shifted by ``start`` so
        that they index the array the column was sliced from.  ``criterion``
        and ``end`` are not used.
        """
        left_indices = start + np.flatnonzero(X_column <= threshold)
        right_indices = start + np.flatnonzero(X_column > threshold)
        return left_indices, right_indices

    def _split(self, X, criterion, start, end):
        """Return the best ``{'feature_index', 'threshold', 'improvement'}``.

        Every distinct value of every feature in rows ``start..end-1`` is
        tried as threshold (optionally sub-sampled to ``max_thresholds``);
        ``criterion`` is updated for each candidate and asked for its
        improvement.  If nothing beats ``-inf`` the feature index is None.
        """
        best_split = {
            'feature_index': None,
            'threshold': None,
            'improvement': -np.inf
        }

        for feature_index in range(X.shape[1]):
            candidates = np.unique(X[start:end, feature_index])
            if len(candidates) <= 1:
                continue  # Skip if the feature has a single unique value

            # If max_thresholds specified and there are more candidates, randomly sample
            if self.max_thresholds and len(candidates) > self.max_thresholds:
                candidates = self.random_state.choice(candidates, self.max_thresholds, replace=False)

            for threshold in candidates:
                criterion.update(feature_index, threshold)
                improvement = criterion.proxy_impurity_improvement()

                if improvement > best_split['improvement']:
                    best_split = {
                        'feature_index': feature_index,
                        'threshold': threshold,
                        'improvement': improvement
                    }

        return best_split

    @staticmethod
    def _make_node(node_value, unique_times, unique_ids, feature=None,
                   threshold=None, left_child=None, right_child=None):
        """Assemble the dictionary describing one node (a leaf by default)."""
        return {
            'feature': feature,
            'threshold': threshold,
            'left_child': left_child,
            'right_child': right_child,
            'node_value': node_value,
            'unique_times': unique_times,
            'ids': unique_ids.tolist()
        }

    @staticmethod
    def _node_criterion(X, y, start, end, template):
        """Build a criterion that covers only rows ``start..end-1`` of the node."""
        ids = y[start:end, 0]
        time_start = y[start:end, 1]
        time_stop = y[start:end, 2]
        event = y[start:end, 3]
        return PseudoScoreCriterion(
            n_outputs=template.n_outputs, n_samples=end - start,
            unique_times=np.unique(np.concatenate([time_start, time_stop])),
            x=X[start:end], ids=ids,
            time_start=time_start, time_stop=time_stop, event=event)

    def _build(self, X, y, criterion, depth=0, start=0, end=None):
        """Recursively build the sub-tree for rows ``start..end-1``.

        ``y`` has the columns ``(id, start, stop, event)``.  ``criterion`` is
        used for the node value; for any node that is not the full array
        handed to the root call, a node-local criterion is built so that
        candidate splits are scored on the node's own rows only.
        """
        if end is None:
            end = X.shape[0]

        ids = y[start:end, 0]
        unique_ids = np.unique(ids)

        # Initialize RisksetCounter for the current node and compute the n_at_risk and n_events arrays
        riskset_counter = RisksetCounter(ids, y[start:end, 1], y[start:end, 2], y[start:end, 3])
        node_unique_times = riskset_counter.all_unique_times.tolist()
        # Adjust node_value length based on unique times of the node
        node_value = criterion.node_value_from_riskset(riskset_counter)[:len(node_unique_times)]

        leaf = self._make_node(node_value, node_unique_times, unique_ids)

        # If current depth is equal to or greater than max_depth, stop further splits
        if self.max_depth is not None and depth >= self.max_depth:
            return leaf

        # If the number of unique IDs in the current node is less than min_ids_split, stop further splits
        if len(unique_ids) < self.min_ids_split:
            return leaf

        # The criterion handed in describes the root data.  A child must be
        # scored on its own rows, otherwise every node would evaluate its
        # candidate splits on the complete data set.
        if depth > 0 or start != 0 or end != X.shape[0]:
            split_criterion = self._node_criterion(X, y, start, end, criterion)
        else:
            split_criterion = criterion

        best_split = self._split(X, split_criterion, start, end)

        # No usable split, or improvement below min_impurity_decrease: stop
        if best_split['feature_index'] is None or best_split['improvement'] < self.min_impurity_decrease:
            return leaf

        left_indices, right_indices = self.split_indices(
            X[start:end, best_split['feature_index']], best_split['threshold'],
            split_criterion, start, end)

        # Both children need data and at least min_ids_leaf distinct ids.
        # NOTE: a violating best split turns the node into a leaf; weaker
        # candidate splits are not tried.
        required_ids = max(1, self.min_ids_leaf)
        if (len(np.unique(y[left_indices, 0])) < required_ids
                or len(np.unique(y[right_indices, 0])) < required_ids):
            return leaf

        # Build the left child
        left_child = self._build(X[left_indices], y[left_indices], criterion, depth=depth+1)

        # Build the right child
        right_child = self._build(X[right_indices], y[right_indices], criterion, depth=depth+1)

        return self._make_node(node_value, node_unique_times, unique_ids,
                               feature=best_split['feature_index'],
                               threshold=best_split['threshold'],
                               left_child=left_child, right_child=right_child)

    def build(self, X, ids, time_start, time_stop, event):
        """
        The main method to invoke the tree building process.
        Initializes the pseudo-score criterion using the input data and constructs the tree using the _build method.

        Returns the root node dictionary.
        """
        n_samples, n_features = X.shape
        y = np.c_[ids, time_start, time_stop, event]

        unique_times = np.unique(np.concatenate([time_start, time_stop]))
        criterion = PseudoScoreCriterion(n_outputs=1, n_samples=n_samples,
                                         unique_times=unique_times, x=X, ids=ids,
                                         time_start=time_start, time_stop=time_stop, event=event)

        tree = self._build(X, y, criterion)
        return tree


In [None]:
data

In [None]:
ids = data['id'].values
time_start = data['start'].values
time_stop = data['stop'].values
event = data['event'].values
x = data[['gender','group']].values

In [None]:
# Initialize and build the tree using PseudoScoreTreeBuilder
tree_builder = PseudoScoreTreeBuilder(max_depth=2, min_ids_leaf=10, random_state=1190)

In [None]:
tree_df = tree_builder.build(x, ids, time_start, time_stop, event)

# Display the tree dataframe
tree_df

In [None]:
class RecurrentTree:
    """
    Single decision tree for recurrent-event data.

    ``fit`` delegates tree construction to ``PseudoScoreTreeBuilder`` and keeps
    the result in ``tree_``. The methods below read these node keys:
    ``feature``, ``threshold``, ``left_child``, ``right_child`` and
    ``node_value``. A node whose ``feature`` is ``None`` is a leaf.
    """

    def __init__(self, max_depth=None, min_ids_split=2, min_ids_leaf=1,
                 max_features=None, max_thresholds=None, min_impurity_decrease=0,
                 random_state=None):
        """
        Store the hyperparameters; they are forwarded unchanged to
        ``PseudoScoreTreeBuilder`` when ``fit`` is called.
        """
        self.max_depth = max_depth
        self.min_ids_split = min_ids_split
        self.min_ids_leaf = min_ids_leaf
        self.max_features = max_features
        self.max_thresholds = max_thresholds
        self.min_impurity_decrease = min_impurity_decrease
        self.random_state = random_state
        self.tree_ = None

    def fit(self, X, ids, time_start, time_stop, event):
        """
        Grow the tree from row-aligned inputs and return ``self``.

        All five arguments are converted to NumPy arrays before being handed
        to the builder.
        """
        X = np.array(X)
        ids = np.array(ids)
        time_start = np.array(time_start)
        time_stop = np.array(time_stop)
        event = np.array(event)

        builder = PseudoScoreTreeBuilder(
            max_depth=self.max_depth,
            min_ids_split=self.min_ids_split,
            min_ids_leaf=self.min_ids_leaf,
            max_features=self.max_features,
            max_thresholds=self.max_thresholds,
            min_impurity_decrease=self.min_impurity_decrease,
            random_state=self.random_state
        )
        self.tree_ = builder.build(X, ids, time_start, time_stop, event)
        return self

    def get_tree(self):
        """Return the tree as a dictionary."""
        return self.tree_

    def _traverse_tree(self, x, node):
        """Follow the split rules from ``node`` down and return the leaf reached by ``x``."""
        while node["feature"] is not None:
            if x[node["feature"]] <= node["threshold"]:
                node = node["left_child"]
            else:
                node = node["right_child"]
        return node

    def predict_mean_function(self, X, ids):
        """
        Return, for every row of ``X``, the ``node_value`` of the leaf it falls in.

        ``ids`` is accepted so the signature matches the other prediction
        methods; it is not used here.
        """
        X = np.array(X)
        return [self._traverse_tree(sample, self.tree_)["node_value"] for sample in X]

    def predict_rate_function(self, X, ids):
        """
        Return, for every row of ``X``, the successive differences of its
        predicted mean function.

        NOTE(review): ``prepend=func[0]`` makes the first entry always 0, so the
        jump at the first time point is not reported. Confirm this is intended
        rather than ``prepend=0``.
        """
        mean_functions = self.predict_mean_function(X, ids)
        return [np.diff(func, prepend=func[0]) for func in mean_functions]

    def apply(self, X, ids):
        """
        Return the leaf index of each unique ID, ordered by sorted ID.

        Each ID is represented by its first row in ``X``. Leaf indices use heap
        numbering (root 0, left child ``2*i + 1``, right child ``2*i + 2``).
        """
        # Keep X in its original dtype. Casting to float32 (as this method used
        # to do) can flip a ``<=`` comparison for values that sit exactly on a
        # threshold, sending a row to a different leaf than
        # predict_mean_function does.
        X = np.asarray(X)
        ids = np.asarray(ids)

        # return_index yields the first occurrence of every unique ID in one pass.
        _, first_rows = np.unique(ids, return_index=True)
        return np.array([self._get_leaf_index(X[row], self.tree_) for row in first_rows])

    def _get_leaf_index(self, x, node, current_index=0):
        """Walk from ``node`` to a leaf and return the leaf's heap-style index."""
        while node["feature"] is not None:
            if x[node["feature"]] <= node["threshold"]:
                node, current_index = node["left_child"], current_index * 2 + 1
            else:
                node, current_index = node["right_child"], current_index * 2 + 2
        return current_index


### RecurrentTree V2.

In [None]:
class RecurrentTree:
    """
    Single decision tree for recurrent-event data, with per-subject prediction.

    ``fit`` delegates tree construction to ``PseudoScoreTreeBuilder`` and keeps
    the result in ``tree_``. The methods below read these node keys:
    ``feature``, ``threshold``, ``left_child``, ``right_child``,
    ``node_value`` and (optionally) ``unique_times``. A node whose ``feature``
    is ``None`` is a leaf.
    """

    def __init__(self, max_depth=None, min_ids_split=2, min_ids_leaf=1,
                 max_features=None, max_thresholds=None, min_impurity_decrease=0,
                 random_state=None):
        """
        Store the hyperparameters; they are forwarded unchanged to
        ``PseudoScoreTreeBuilder`` when ``fit`` is called.
        """
        self.max_depth = max_depth
        self.min_ids_split = min_ids_split
        self.min_ids_leaf = min_ids_leaf
        self.max_features = max_features
        self.max_thresholds = max_thresholds
        self.min_impurity_decrease = min_impurity_decrease
        self.random_state = random_state
        self.tree_ = None

    def fit(self, X, ids, time_start, time_stop, event):
        """
        Grow the tree from row-aligned inputs and return ``self``.
        """
        X = np.array(X)
        ids = np.array(ids)
        time_start = np.array(time_start)
        time_stop = np.array(time_stop)
        event = np.array(event)

        builder = PseudoScoreTreeBuilder(
            max_depth=self.max_depth,
            min_ids_split=self.min_ids_split,
            min_ids_leaf=self.min_ids_leaf,
            max_features=self.max_features,
            max_thresholds=self.max_thresholds,
            min_impurity_decrease=self.min_impurity_decrease,
            random_state=self.random_state
        )
        self.tree_ = builder.build(X, ids, time_start, time_stop, event)
        return self

    def get_tree(self):
        """Return the tree as a dictionary."""
        return self.tree_

    @staticmethod
    def _node_key(node):
        """Return the identity of ``node``, used as the key of the leaf numbering."""
        # Use the ``builtins`` module explicitly: ``__builtins__`` is a module
        # in ``__main__`` but a plain dict in imported modules, so
        # ``__builtins__.id`` fails as soon as this class is imported.
        import builtins
        return builtins.id(node)

    def traverse_tree_for_id(self, X_id_samples, node):
        """
        Return the terminal node for one subject.

        Args:
        - X_id_samples (sequence of arrays): The rows belonging to the subject.
        - node (dict): The node to start from.

        Returns:
        - node (dict): The leaf reached by the subject's FIRST row. When the
          rows of a subject would reach different leaves the first row wins.

        Raises:
        - IndexError: if ``X_id_samples`` is empty and ``node`` is not a leaf.
        """
        # Only the first row is routed. The previous implementation routed
        # every row and compared the resulting leaf dicts with ``==``, which
        # raises ValueError when leaves hold NumPy arrays, and then returned
        # the first row's leaf in every case anyway.
        while node["feature"] is not None:
            sample = X_id_samples[0]
            if sample[node["feature"]] <= node["threshold"]:
                node = node["left_child"]
            else:
                node = node["right_child"]
        return node

    def predict_mean_function(self, X, ids):
        """
        Predict the mean function of every subject in ``ids``.

        Args:
        - X: covariate rows, aligned with ``ids``.
        - ids: one subject ID per row of ``X`` (a scalar is treated as a
          one-element list).

        Returns:
        - dict mapping each subject ID (in order of first appearance) to
          ``{"unique_times": ..., "mean_function": ...}`` taken from the
          subject's leaf.
        """
        X = np.array(X)
        if np.isscalar(ids):
            ids = [ids]
        ids_array = np.asarray(ids)

        mean_function_predictions = {}
        for sample_id in ids:
            # Every row of a subject yields the same entry; compute it once.
            if sample_id in mean_function_predictions:
                continue
            rows = np.flatnonzero(ids_array == sample_id)
            leaf = self.traverse_tree_for_id(X[rows], self.tree_)
            mean_function_predictions[sample_id] = {
                "unique_times": leaf.get('unique_times', []),
                "mean_function": leaf["node_value"]
            }

        return mean_function_predictions

    def predict_rate_function(self, X, ids):
        """
        Predict, per subject, the successive differences of the mean function.

        Returns a dict mapping each subject ID to ``{"times": ..., "rates": ...}``.

        NOTE(review): ``prepend=values[0]`` makes the first rate always 0, so
        the jump at the first time point is not reported. Confirm this is
        intended rather than ``prepend=0``.
        """
        if np.isscalar(ids):
            ids = [ids]

        mean_function_predictions = self.predict_mean_function(X, ids)

        rate_function_predictions = {}
        for sample_id, prediction in mean_function_predictions.items():
            values = prediction["mean_function"]
            rate_function_predictions[sample_id] = {
                "times": prediction["unique_times"],
                "rates": np.diff(values, prepend=values[0])
            }

        return rate_function_predictions

    def _map_terminal_nodes(self, node):
        """
        Recursively number the leaves below ``node`` (left to right, from 0)
        in ``self.terminal_node_mapping``.
        """
        if node["feature"] is None:  # Terminal node
            key = self._node_key(node)
            if key not in self.terminal_node_mapping:
                self.terminal_node_mapping[key] = len(self.terminal_node_mapping)
            return

        self._map_terminal_nodes(node["left_child"])
        self._map_terminal_nodes(node["right_child"])

    def apply(self, X, ids=None):
        """
        Return a dict mapping each unique ID to the number of its leaf.

        When ``ids`` is omitted every row of ``X`` is its own subject, labelled
        by row position. Each subject is represented by its first row.
        """
        # Keep X in its original dtype. Casting to float32 (as this method used
        # to do) can flip a ``<=`` comparison for values that sit exactly on a
        # threshold, sending a row to a different leaf than
        # predict_mean_function does.
        X = np.asarray(X)
        if ids is None:
            ids = np.arange(X.shape[0])
        else:
            ids = np.asarray(ids)

        self.terminal_node_mapping = {}  # Reset the mapping
        self._map_terminal_nodes(self.tree_)

        terminal_nodes = {}
        unique_ids, first_rows = np.unique(ids, return_index=True)
        for sample_id, row in zip(unique_ids, first_rows):
            leaf = self.traverse_tree_for_id(X[row:row + 1], self.tree_)
            terminal_nodes[sample_id] = self.terminal_node_mapping[self._node_key(leaf)]

        return terminal_nodes

In [None]:
# Covariate matrix with the columns swapped: feature 0 is group, feature 1 is gender.
x = data[['group','gender']].values

In [None]:
# 2. RecurrentTree training and prediction
tree_model = RecurrentTree(max_depth=2, random_state=1190)
tree_model.fit(x, ids, time_start, time_stop, event)

In [None]:
# Install the graphviz Python package used by the visualisation cells below.
# NOTE: rendering additionally needs the Graphviz executables (`dot`) on PATH.
%pip install graphviz

In [None]:
# Show the fitted tree dictionary.
tree_model.get_tree()

In [None]:
# Keep a reference to the fitted tree dictionary for the visualisation cells, then show it.
tree = tree_model.get_tree()
tree

In [None]:
# Predict the mean function of every subject, then print the entries of four subject IDs.
# The result is a dict keyed by subject ID.
predict_mean_function=tree_model.predict_mean_function(x,ids=ids)
print("ID 1 predicted mean function:",predict_mean_function[1])
print("ID 2 predicted mean function:",predict_mean_function[2])
print("ID 26 predicted mean function:",predict_mean_function[26])
print("ID 27 predicted mean function:",predict_mean_function[27])

In [None]:
# Map every subject ID to the number of the leaf it falls in.
tree_model.apply(x,ids=ids)

In [None]:
import graphviz

def visualize_tree_simple(tree):
    """
    Render a tree dictionary as a graphviz ``Digraph``.

    Internal nodes are labelled with their split rule. Leaves are drawn as
    boxes that list every entry of ``node_value`` on its own line. Edges are
    labelled "True" for the left branch and "False" for the right branch.
    """
    dot_graph = graphviz.Digraph()
    issued_names = []  # one entry per drawn node; its length is the next node number

    def draw(subtree, parent=None, branch=None):
        if subtree is None:
            return

        # Nodes are numbered in visiting (pre-)order.
        current = f"node{len(issued_names)}"
        issued_names.append(current)

        split_feature = subtree["feature"]
        if split_feature is not None:
            rule = f"Feature {split_feature} <= {subtree['threshold']:.2f}"
            dot_graph.node(current, label=rule)
            draw(subtree['left_child'], current, branch="True")
            draw(subtree['right_child'], current, branch="False")
        else:
            rows = [f"t{idx}: {value:.2f}" for idx, value in enumerate(subtree['node_value'])]
            dot_graph.node(current, label="\n".join(rows), shape="box")

        # The edge from the parent is emitted after the whole subtree.
        if parent:
            dot_graph.edge(parent, current, label=branch)

    draw(tree)
    return dot_graph


In [None]:
# Build the graph for the fitted tree and open the rendered file in a viewer.
dot = visualize_tree_simple(tree)
dot.view()

In [None]:
def visualize_tree_with_data(tree):
    """
    Render a tree dictionary as a graphviz ``Digraph``, showing leaf members.

    Internal nodes are labelled with their split rule; each leaf is a box
    listing the subject IDs stored under its ``ids`` key. Edges are labelled
    "True" for the left branch and "False" for the right branch.
    """
    import graphviz

    dot_graph = graphviz.Digraph()
    issued_names = []  # one entry per drawn node; its length is the next node number

    def draw(subtree, parent=None, branch=None):
        if subtree is None:
            return

        # Nodes are numbered in visiting (pre-)order.
        current = f"node{len(issued_names)}"
        issued_names.append(current)

        split_feature = subtree["feature"]
        if split_feature is not None:
            rule = f"Feature {split_feature} <= {subtree['threshold']:.2f}"
            dot_graph.node(current, label=rule)
            draw(subtree['left_child'], current, branch="True")
            draw(subtree['right_child'], current, branch="False")
        else:
            dot_graph.node(current, label=f"Unique IDs: {subtree['ids']}", shape="box")

        # The edge from the parent is emitted after the whole subtree.
        if parent:
            dot_graph.edge(parent, current, label=branch)

    draw(tree)
    return dot_graph




In [None]:
# Build the graph that lists the subject IDs of each leaf and open it in a viewer.
dot = visualize_tree_with_data(tree)
dot.view()

In [None]:
import graphviz

def visualize_tree_simple(tree):
    """
    Render only the structure of a tree dictionary as a graphviz ``Digraph``.

    Internal nodes show their split rule; every leaf is a box with the fixed
    label "Leaf Node". Edges are labelled "True" (left) or "False" (right).
    """
    dot_graph = graphviz.Digraph()
    state = {"next_number": 0}  # running node number shared by all recursive calls

    def draw(subtree, parent=None, branch=None):
        if subtree is None:
            return

        # Nodes are numbered in visiting (pre-)order.
        current = "node{}".format(state["next_number"])
        state["next_number"] += 1

        is_leaf = subtree["feature"] is None
        if is_leaf:
            dot_graph.node(current, label="Leaf Node", shape="box")
        if not is_leaf:
            rule = f"Feature {subtree['feature']} <= {subtree['threshold']:.2f}"
            dot_graph.node(current, label=rule)
            for child_key, outcome in (('left_child', "True"), ('right_child', "False")):
                draw(subtree[child_key], current, branch=outcome)

        # The edge from the parent is emitted after the whole subtree.
        if parent:
            dot_graph.edge(parent, current, label=branch)

    draw(tree)
    return dot_graph


In [None]:
# Build the structure-only graph for the fitted tree and open it in a viewer.
dot = visualize_tree_simple(tree)
dot.view()

### Mission Clear

In [None]:
# Covariates for this run: feature 0 = group, feature 1 = x1, feature 2 = gender.
x = data[['group','x1','gender']].values
# 2. RecurrentTree training and prediction
tree_model = RecurrentTree(max_depth=3, random_state=1190)
tree_model.fit(x, ids, time_start, time_stop, event)

In [None]:
# Keep the fitted tree dictionary for the visualisation cells below.
tree=tree_model.get_tree()

In [None]:
# Predict the mean function of every subject and show the entry of subject 64.
predict_mean_function=tree_model.predict_mean_function(x,ids)
predict_mean_function[64]

In [None]:
# Map every subject ID to the number of the leaf it falls in.
tree_model.apply(x,ids)

In [None]:
import graphviz

def visualize_tree_simple(tree):
    """
    Draw the bare structure of a tree dictionary with graphviz.

    Split nodes carry their rule as label, leaves are boxes labelled
    "Leaf Node", and edges are tagged "True" (left) or "False" (right).
    Returns the ``Digraph``.
    """
    canvas = graphviz.Digraph()
    drawn = []  # names handed out so far; len(drawn) is the next node number

    def walk(current_node, parent_name=None, decision=None):
        if current_node is None:
            return

        # Name the node by its position in the visiting order.
        name = f"node{len(drawn)}"
        drawn.append(name)

        feature = current_node["feature"]
        if feature is None:
            canvas.node(name, label="Leaf Node", shape="box")
        else:
            canvas.node(name, label=f"Feature {feature} <= {current_node['threshold']:.2f}")
            walk(current_node['left_child'], name, decision="True")
            walk(current_node['right_child'], name, decision="False")

        # Link to the parent once the subtree below this node is complete.
        if parent_name:
            canvas.edge(parent_name, name, label=decision)

    walk(tree)
    return canvas


In [None]:
# Build the structure-only graph for the depth-3 tree and open it in a viewer.
dot = visualize_tree_simple(tree)
dot.view()

In [None]:
import graphviz

def visualize_tree_simple(tree):
    """
    Draw a tree dictionary with graphviz, including the leaf values.

    Split nodes carry their rule as label. Each leaf is a box whose label has
    one line per entry of ``node_value``. Edges are tagged "True" (left) or
    "False" (right). Returns the ``Digraph``.
    """
    canvas = graphviz.Digraph()
    state = {"next_number": 0}  # running node number shared by all recursive calls

    def walk(current_node, parent_name=None, decision=None):
        if current_node is None:
            return

        # Name the node by its position in the visiting order.
        name = "node{}".format(state["next_number"])
        state["next_number"] += 1

        feature = current_node["feature"]
        if feature is None:
            value_lines = []
            for idx, value in enumerate(current_node['node_value']):
                value_lines.append(f"t{idx}: {value:.2f}")
            canvas.node(name, label="\n".join(value_lines), shape="box")
        else:
            canvas.node(name, label=f"Feature {feature} <= {current_node['threshold']:.2f}")
            walk(current_node['left_child'], name, decision="True")
            walk(current_node['right_child'], name, decision="False")

        # Link to the parent once the subtree below this node is complete.
        if parent_name:
            canvas.edge(parent_name, name, label=decision)

    walk(tree)
    return canvas


In [None]:
# Build the graph with leaf values for the depth-3 tree and open it in a viewer.
dot = visualize_tree_simple(tree)
dot.view()

In [None]:
def visualize_tree_with_data(tree):
    """
    Draw a tree dictionary with graphviz, listing the members of each leaf.

    Split nodes carry their rule as label; each leaf is a box showing the
    subject IDs stored under its ``ids`` key. Edges are tagged "True" (left)
    or "False" (right). Returns the ``Digraph``.
    """
    import graphviz

    canvas = graphviz.Digraph()
    state = {"next_number": 0}  # running node number shared by all recursive calls

    def walk(current_node, parent_name=None, decision=None):
        if current_node is None:
            return

        # Name the node by its position in the visiting order.
        name = "node{}".format(state["next_number"])
        state["next_number"] += 1

        feature = current_node["feature"]
        if feature is None:
            members = f"Unique IDs: {current_node['ids']}"
            canvas.node(name, label=members, shape="box")
        else:
            canvas.node(name, label=f"Feature {feature} <= {current_node['threshold']:.2f}")
            walk(current_node['left_child'], name, decision="True")
            walk(current_node['right_child'], name, decision="False")

        # Link to the parent once the subtree below this node is complete.
        if parent_name:
            canvas.edge(parent_name, name, label=decision)

    walk(tree)
    return canvas

In [None]:
# Build the graph that lists the subject IDs of each leaf and open it in a viewer.
dot = visualize_tree_with_data(tree)
dot.view()

### RecurrentRandomForest

In [None]:
import numpy as np
from numbers import Integral, Real
from sklearn.utils import check_random_state

def _get_n_samples_bootstrap(n_ids, max_samples):
    """
    Modified for recurrent events. Get the number of IDs in a bootstrap sample.

    Args:
        n_ids: number of distinct subject IDs available.
        max_samples: ``None`` (use ``n_ids``), an integer count in
            ``[1, n_ids]``, or a fraction in ``(0, 1]`` of ``n_ids``.

    Returns:
        The number of IDs to draw (at least 1).

    Raises:
        ValueError: if an integer is outside ``[1, n_ids]`` or a fraction is
            outside ``(0, 1]``.
        TypeError: if ``max_samples`` is neither ``None`` nor a real number.
    """
    if max_samples is None:
        return n_ids

    if isinstance(max_samples, Integral):
        if max_samples > n_ids:
            msg = "`max_samples` must be <= n_ids={} but got value {}"
            raise ValueError(msg.format(n_ids, max_samples))
        if max_samples < 1:
            raise ValueError(
                "`max_samples` must be >= 1 but got value {}".format(max_samples)
            )
        return max_samples

    if isinstance(max_samples, Real):
        # A fraction above 1 would ask for more IDs than the integer branch
        # allows; NaN fails this comparison as well.
        if not 0 < max_samples <= 1:
            raise ValueError(
                "`max_samples` must be in (0, 1] when given as a fraction "
                "but got value {}".format(max_samples)
            )
        return max(round(n_ids * max_samples), 1)

    # Falling through used to return None, which downstream sampling treats as
    # "draw a single ID" instead of failing.
    raise TypeError(
        "`max_samples` must be None, an int or a float but got {!r}".format(max_samples)
    )

def _generate_sample_indices(random_state, ids, n_ids_bootstrap):
    """
    Draw ``n_ids_bootstrap`` IDs from ``ids`` with replacement.

    Args:
        random_state: ``None``, an int seed or a ``RandomState``; resolved with
            ``check_random_state``. The same int seed gives the same draw.
        ids: the IDs to draw from (callers pass the unique subject IDs).
        n_ids_bootstrap: number of IDs to draw.

    Returns:
        Array of drawn IDs (may contain repeats). Expanding IDs to their rows
        is left to the caller.
    """
    random_instance = check_random_state(random_state)
    # Draw from the resolved generator; the global ``np.random`` would ignore
    # ``random_state`` entirely.
    return random_instance.choice(ids, n_ids_bootstrap, replace=True)

def _generate_unsampled_indices(random_state, ids, n_ids_bootstrap):
    """
    Return the positions in ``ids`` whose ID is absent from a bootstrap draw.

    A draw of ``n_ids_bootstrap`` IDs is taken from the distinct values of
    ``ids`` via ``_generate_sample_indices``. Every position of ``ids`` that
    holds an ID missing from that draw is reported.
    """
    candidates = np.unique(ids)
    drawn = _generate_sample_indices(random_state, candidates, n_ids_bootstrap)

    # IDs that never came up in the draw (repeats in the draw collapse first).
    left_out = np.setdiff1d(candidates, np.unique(drawn))

    # All positions carrying a left-out ID, i.e. all rows of those subjects.
    is_left_out = np.isin(ids, left_out)
    return np.nonzero(is_left_out)[0]

from warnings import catch_warnings, simplefilter
from sklearn.utils.class_weight import compute_sample_weight

def _parallel_build_trees(
    tree,
    bootstrap,
    X,
    y,  # Now, y is expected to be a dictionary with 'id', 'time_start', 'time_stop', and 'event' as keys
    sample_weight,
    tree_idx,
    n_trees,
    verbose=0,
    class_weight=None,
    n_ids_bootstrap=None
):
    """
    Private function used to fit a single tree in parallel for recurrent events.

    Args:
        tree: estimator to fit; its ``random_state`` seeds the ID draw and its
            ``fit`` is called with ``sample_weight`` and ``check_input``.
        bootstrap: when true, fit on the rows of a with-replacement draw of IDs.
        X: covariate rows.
        y: dict with keys 'id', 'time_start', 'time_stop' and 'event'.
        sample_weight: per-row weights, or None for all ones.
        tree_idx, n_trees: only used for the progress message.
        verbose: values above 1 print a progress message.
        class_weight: "subsample" or "balanced_subsample" multiply the weights
            by ``compute_sample_weight`` of the event column.
        n_ids_bootstrap: number of IDs to draw when bootstrapping.

    Returns:
        The fitted ``tree``.

    NOTE(review): ``tree.fit`` is called here as ``fit(X, y_dict,
    sample_weight=..., check_input=...)``, which does not match
    ``RecurrentTree.fit(X, ids, time_start, time_stop, event)`` defined in this
    file - confirm which estimator this helper is meant for.
    """
    # Extract necessary data from y
    ids = y['id']
    time_start = y['time_start']
    time_stop = y['time_stop']
    event = y['event']

    if verbose > 1:
        print("building tree %d of %d" % (tree_idx + 1, n_trees))

    if bootstrap:
        unique_ids = np.unique(ids)
        n_ids = len(unique_ids)

        # Generate bootstrap samples using IDs
        sampled_ids = _generate_sample_indices(
            tree.random_state, unique_ids, n_ids_bootstrap
        )

        # Expand sampled IDs to all their associated events
        # (each row appears once here even if its ID was drawn several times;
        # the multiplicity is carried by the weights below)
        indices = np.where(np.isin(ids, sampled_ids))[0]

        if sample_weight is None:
            curr_sample_weight = np.ones((X.shape[0],), dtype=np.float64)
        else:
            curr_sample_weight = sample_weight.copy()

        # Adjust the sample weight based on how many times each ID was sampled
        # (searchsorted maps IDs to their position in the sorted unique_ids)
        sample_counts_for_ids = np.bincount(np.searchsorted(unique_ids, sampled_ids), minlength=n_ids)
        curr_sample_weight *= sample_counts_for_ids[np.searchsorted(unique_ids, ids)]

        # NOTE(review): the "auto" preset may not be accepted by the installed
        # scikit-learn's compute_sample_weight - verify before relying on it.
        if class_weight == "subsample":
            with catch_warnings():
                simplefilter("ignore", DeprecationWarning)
                curr_sample_weight *= compute_sample_weight("auto", event, indices=indices)
        elif class_weight == "balanced_subsample":
            curr_sample_weight *= compute_sample_weight("balanced", event, indices=indices)

        # Fit on the sampled rows only, passing the matching slice of every response column.
        tree.fit(X[indices], {'id': ids[indices], 'time_start': time_start[indices], 'time_stop': time_stop[indices], 'event': event[indices]}, sample_weight=curr_sample_weight[indices], check_input=False)
    else:
        tree.fit(X, y, sample_weight=sample_weight, check_input=False)

    return tree



In [None]:
from joblib import dump, load

# Round-trip an unfitted RecurrentTree through joblib to check that it can be pickled.
tree = RecurrentTree()
dump(tree, 'test.pkl')  # attempt serialization
loaded_tree = load('test.pkl')  # attempt deserialization

In [None]:
# Show the object restored from 'test.pkl'.
loaded_tree

In [None]:
from sklearn.base import BaseEstimator
from sklearn.utils import check_array
from joblib import Parallel, delayed
from sklearn.utils.validation import check_is_fitted
import warnings
import numpy as np

class RecurrentRandomForest(BaseEstimator):
    """
    A Random Forest model designed for recurrent event data.

    Each tree is a ``RecurrentTree`` fitted on the rows of a random draw of
    subject IDs. Tree predictions are expected in the shape returned by
    ``RecurrentTree.predict_mean_function`` in this file: a dict mapping a
    subject ID to ``{"unique_times": ..., "mean_function": ...}``. The forest
    prediction for a subject is the pointwise average of its per-tree curves.

    Attributes set by ``fit``:
        estimators_: the fitted trees.
        estimators_sampled_ids_: per tree, the subject IDs drawn for it
            (``None`` entries when ``bootstrap=False``).
        n_features_in_: number of columns of the training matrix.
        oob_prediction_, oob_score_: only when ``oob_score=True``.
    """
    def __init__(self, n_estimators=100, max_depth=None, min_ids_split=2,
                 min_ids_leaf=1, bootstrap=True, oob_score=False, n_jobs=None,
                 random_state=None, verbose=0, warm_start=False, max_samples=None,
                 min_impurity_decrease=0.0, max_features=None, max_thresholds=None):
        # Only store the hyperparameters here. The trees are created in fit()
        # so that parameter changes made after construction are honoured and
        # an unfitted forest carries no fitted-looking ``estimators_``.
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.min_ids_split = min_ids_split
        self.min_ids_leaf = min_ids_leaf
        self.bootstrap = bootstrap
        self.oob_score = oob_score
        self.n_jobs = n_jobs
        self.random_state = random_state
        self.verbose = verbose
        self.warm_start = warm_start
        self.max_samples = max_samples
        self.min_impurity_decrease = min_impurity_decrease
        self.max_features = max_features
        self.max_thresholds = max_thresholds

    def _make_estimator(self, random_state=None):
        """
        Construct a new ``RecurrentTree`` with the forest's hyperparameters
        and the given per-tree ``random_state``.
        """
        return RecurrentTree(
            max_depth=self.max_depth,
            min_ids_split=self.min_ids_split,
            min_ids_leaf=self.min_ids_leaf,
            random_state=random_state,
            min_impurity_decrease=self.min_impurity_decrease,
            max_features=self.max_features,
            max_thresholds=self.max_thresholds
        )

    def _draw_tree_seeds(self):
        """Derive one integer seed per tree from ``self.random_state``."""
        from sklearn.utils import check_random_state

        rng = check_random_state(self.random_state)
        upper = np.iinfo(np.int32).max
        return [int(seed) for seed in rng.randint(upper, size=self.n_estimators)]

    def fit(self, X, y, sample_weight=None):
        """
        Train the forest.

        Args:
            X: covariate rows.
            y: mapping with row-aligned entries 'id', 'time_start',
                'time_stop' and 'event'.
            sample_weight: accepted for API compatibility; not used.

        Returns:
            self

        Raises:
            ValueError: if ``oob_score=True`` is combined with ``bootstrap=False``.
        """
        if self.oob_score and not self.bootstrap:
            raise ValueError("Out of bag estimation only available if bootstrap=True")

        X = self._validate_data(X, accept_sparse='csc', ensure_min_samples=2)
        ids = np.asarray(y['id'])
        time_start = np.asarray(y['time_start'])
        time_stop = np.asarray(y['time_stop'])
        event = np.asarray(y['event'])
        self.n_features_in_ = X.shape[1]

        unique_ids = np.unique(ids)
        n_samples_bootstrap = _get_n_samples_bootstrap(len(unique_ids), self.max_samples)

        seeds = self._draw_tree_seeds()
        trees = [self._make_estimator(random_state=seed) for seed in seeds]

        # Draw the IDs here, in the parent process, and remember them: the
        # out-of-bag rows of a tree are then exactly the rows it did not see.
        if self.bootstrap:
            sampled_ids_per_tree = [
                np.random.RandomState(seed).choice(unique_ids, n_samples_bootstrap, replace=True)
                for seed in seeds
            ]
        else:
            sampled_ids_per_tree = [None] * len(trees)

        def _fit_tree(tree, sampled_ids):
            if sampled_ids is None:
                tree.fit(X, ids, time_start, time_stop, event)
                return tree
            # Every drawn subject contributes its rows once, however often its
            # ID was drawn (a repeated ID would otherwise look like one subject
            # with duplicated intervals).
            rows = np.where(np.isin(ids, sampled_ids))[0]
            tree.fit(X[rows], ids[rows], time_start[rows], time_stop[rows], event[rows])
            return tree

        self.estimators_ = Parallel(n_jobs=self.n_jobs)(
            delayed(_fit_tree)(tree, sampled_ids)
            for tree, sampled_ids in zip(trees, sampled_ids_per_tree)
        )
        self.estimators_sampled_ids_ = sampled_ids_per_tree

        if self.oob_score:
            self._set_oob_score_and_attributes(X, y)
        return self

    def _set_oob_score_and_attributes(self, X, y):
        """
        Compute out-of-bag predictions and score.

        For every tree the rows of the subjects that were not drawn for it are
        predicted; the per-tree curves of a subject are then averaged. Sets
        ``oob_prediction_`` (subjects that were never out of bag are absent)
        and ``oob_score_``.
        """
        if not self.bootstrap:
            raise ValueError("Out of bag estimation only available if bootstrap=True")

        ids = np.asarray(y['id'])
        event = np.asarray(y['event'])

        per_tree_predictions = []
        for estimator, sampled_ids in zip(self.estimators_, self.estimators_sampled_ids_):
            oob_rows = np.where(~np.isin(ids, sampled_ids))[0]
            if oob_rows.size == 0:
                continue
            per_tree_predictions.append(
                estimator.predict_mean_function(X[oob_rows, :], ids[oob_rows])
            )

        aggregated_predictions = self._aggregate_predictions(per_tree_predictions, ids)
        self.oob_prediction_ = aggregated_predictions

        # NOTE(review): "total events" is taken as the sum of the event column
        # per subject - confirm this is the quantity the score expects.
        total_events = {uid: event[ids == uid].sum() for uid in np.unique(ids)}
        self.oob_score_ = self._estimate_recurrent_concordance_index(
            aggregated_predictions, X, event, ids, total_events
        )

    def _estimate_recurrent_concordance_index(self, predictions, X, event, ids, total_events):
        """
        Estimate the C-index for recurrent events.

        Parameters:
        - predictions: either a dict mapping subject ID to a dict with a
          'mean_function' entry (subjects without an entry are left out of
          the score), or a row-aligned sequence of curves, which are averaged
          element by element per subject.
        - X: The data matrix (not used).
        - event: Event column, row-aligned with ``ids``.
        - ids: Subject ID of each row.
        - total_events: Per-subject totals, indexed by subject ID.

        Returns:
        - concordant pairs / permissible pairs, or 0 when no pair is permissible.

        NOTE(review): the pair rules below are kept as originally written.
        They use only the FIRST event value of each subject, and they order two
        predicted curves with list comparison (decided by the first differing
        element). Both look provisional - confirm the intended definition.
        """
        ids = np.asarray(ids)
        event = np.asarray(event)
        unique_ids, first_rows = np.unique(ids, return_index=True)

        id_to_avg_prediction = {}
        if isinstance(predictions, dict):
            for uid in unique_ids:
                if uid in predictions:
                    id_to_avg_prediction[uid] = list(predictions[uid]['mean_function'])
        else:
            for uid in unique_ids:
                uid_indices = np.where(ids == uid)[0]
                uid_predictions = [predictions[i] for i in uid_indices]

                # Element-wise mean over the subject's rows; shorter curves
                # simply stop contributing.
                max_length = max(map(len, uid_predictions))
                id_to_avg_prediction[uid] = [
                    np.mean([pred[k] for pred in uid_predictions if k < len(pred)])
                    for k in range(max_length)
                ]

        # First event value of every subject, looked up once.
        first_event = {uid: event[row] for uid, row in zip(unique_ids, first_rows)}
        scored_ids = [uid for uid in unique_ids if uid in id_to_avg_prediction]

        concordant_pairs = 0
        permissible_pairs = 0

        for position, uid_i in enumerate(scored_ids):
            for uid_j in scored_ids[position + 1:]:
                right_censored_i = first_event[uid_i] < total_events[uid_i]
                right_censored_j = first_event[uid_j] < total_events[uid_j]

                if not right_censored_i and not right_censored_j:  # Both are not right-censored
                    if first_event[uid_i] > first_event[uid_j]:
                        permissible_pairs += 1
                        if id_to_avg_prediction[uid_i] > id_to_avg_prediction[uid_j]:
                            concordant_pairs += 1
                else:  # At least one is right-censored
                    if not right_censored_i:  # i is not right-censored but j is
                        permissible_pairs += 1
                        if id_to_avg_prediction[uid_i] > id_to_avg_prediction[uid_j]:
                            concordant_pairs += 1
                    elif not right_censored_j:  # j is not right-censored but i is
                        permissible_pairs += 1
                        if id_to_avg_prediction[uid_i] < id_to_avg_prediction[uid_j]:
                            concordant_pairs += 1

        # A proportion of concordant pairs lies in [0, 1]; the former factor 2
        # could push the result above 1.
        return concordant_pairs / permissible_pairs if permissible_pairs > 0 else 0

    def _validate_data(self, X, accept_sparse=False, ensure_min_samples=1):
        """Validate input data('X') to ensure it's in the correct format and meets the necessary conditions for processing."""
        return check_array(X, accept_sparse=accept_sparse, ensure_min_samples=ensure_min_samples)

    def _validate_X_predict(self, X):
        """Validate X whenever one tries to predict."""
        X = check_array(X)
        if X.shape[1] != self.n_features_in_:
            raise ValueError("Number of features of the model must match the input. Model n_features is {} and input n_features is {}."
                             .format(self.n_features_in_, X.shape[1]))
        return X

    @staticmethod
    def _step_function_on_grid(prediction, grid=None):
        """
        Read one tree prediction as a step function.

        With ``grid=None`` return the prediction's sorted ``(times, values)``.
        Otherwise return its values on ``grid``: the value at the last time
        that is <= the grid point, and 0 before the first time (assumes the
        mean function starts at 0 and is right-continuous).
        """
        values = np.asarray(prediction['mean_function'], dtype=float).ravel()
        times = np.asarray(prediction.get('unique_times', []), dtype=float).ravel()
        if times.size == 0 and values.size:
            # No time axis supplied: fall back to the position as time.
            times = np.arange(values.size, dtype=float)
        if times.size != values.size:
            raise ValueError(
                "unique_times and mean_function must have the same length, "
                "got {} and {}".format(times.size, values.size)
            )
        order = np.argsort(times, kind='stable')
        times, values = times[order], values[order]
        if grid is None:
            return times, values

        on_grid = np.zeros(grid.shape, dtype=float)
        last_known = np.searchsorted(times, grid, side='right') - 1
        has_value = last_known >= 0
        on_grid[has_value] = values[last_known[has_value]]
        return on_grid

    def _aggregate_predictions(self, all_predictions, ids):
        """
        Average the per-tree mean functions of each unique ID.

        Args:
            all_predictions: iterable with one entry per tree, each a dict
                mapping subject ID to ``{'unique_times', 'mean_function'}``.
            ids: subject IDs to report; IDs no tree predicted are omitted.

        Returns:
            dict mapping subject ID to ``{'unique_times': [...],
            'mean_function': [...]}`` on the union of the trees' time points.
        """
        curves_by_id = {}
        for tree_predictions in all_predictions:
            for uid, prediction in tree_predictions.items():
                curves_by_id.setdefault(uid, []).append(prediction)

        aggregated_predictions = {}
        for uid in np.unique(ids):
            curves = curves_by_id.get(uid)
            if not curves:
                continue
            grid = np.unique(np.concatenate(
                [self._step_function_on_grid(curve)[0] for curve in curves]
            ))
            stacked = np.vstack([self._step_function_on_grid(curve, grid) for curve in curves])
            aggregated_predictions[uid] = {
                'unique_times': grid.tolist(),
                'mean_function': stacked.mean(axis=0).tolist()
            }

        return aggregated_predictions

    def _calculate_mean_function(self, unique_times, n_at_risk, n_events):
        """
        Cumulate ``n_events / n_at_risk`` over the time points.

        Not called by the forest itself; kept for callers that hold pooled
        at-risk and event counts. Time points with nobody at risk add 0
        instead of dividing by zero.
        """
        n_at_risk = np.asarray(n_at_risk, dtype=float)
        n_events = np.asarray(n_events, dtype=float)
        increments = np.divide(n_events, n_at_risk,
                               out=np.zeros_like(n_events), where=n_at_risk > 0)
        return {
            'unique_times': unique_times,
            'mean_function': np.cumsum(increments).tolist()
        }

    def predict_mean_function(self, X, ids):
        """
        Predict the ensemble mean function of each subject.

        Args:
            X: covariate rows.
            ids: one subject ID per row of ``X``. A scalar labels every row of
                ``X`` as that subject, so pass only that subject's rows.

        Returns:
            dict mapping subject ID to ``{'unique_times', 'mean_function'}``.

        Raises:
            ValueError: if ``ids`` and ``X`` differ in length.
        """
        check_is_fitted(self, "estimators_")
        X = self._validate_X_predict(X)

        ids = np.asarray(ids)
        if ids.ndim == 0:
            ids = np.repeat(ids, X.shape[0])
        if ids.shape[0] != X.shape[0]:
            raise ValueError(
                "ids must hold one entry per row of X, got {} ids for {} rows"
                .format(ids.shape[0], X.shape[0])
            )

        per_tree_predictions = [tree.predict_mean_function(X, ids) for tree in self.estimators_]
        return self._aggregate_predictions(per_tree_predictions, ids)

    def predict_rate_function(self, X, ids):
        """Predict the rate function as the slope of the mean function between consecutive time points."""
        mean_functions = self.predict_mean_function(X, ids)
        rate_functions = {}

        for uid, mean_function in mean_functions.items():
            unique_times = mean_function['unique_times']
            rates = np.diff(mean_function['mean_function']) / np.diff(unique_times)
            rate_functions[uid] = {
                'unique_times': unique_times[:-1],  # the last time point has no following interval
                'rates': rates
            }

        return rate_functions

In [None]:
# Covariates for the forest: feature 0 = group, feature 1 = x1, feature 2 = gender.
x = data[['group','x1','gender']].values

In [None]:
# Forest of 100 depth-3 trees; oob_score=True makes fit() also run the out-of-bag scoring.
rrf = RecurrentRandomForest(n_estimators=100, max_depth=3, min_ids_leaf=10, min_impurity_decrease=0.2, random_state=1190, oob_score=True, n_jobs=6)
# fit() reads exactly these four keys from y; every entry is row-aligned with x.
y = {
    'id': ids,
    'time_start': time_start,
    'time_stop': time_stop,
    'event': event
}
rrf.fit(x,y)

In [None]:
# Pass only subject 1's rows, with one ID per row, so the prediction is built
# from that subject's covariates rather than from the first row of x.
rrf.predict_mean_function(X=x[ids == 1], ids=ids[ids == 1])

In [None]:
# Pass only subject 2's rows, with one ID per row, so the prediction is built
# from that subject's covariates rather than from the first row of x.
rrf.predict_mean_function(X=x[ids == 2], ids=ids[ids == 2])

In [None]:
# Out-of-bag score; the attribute is only set because the forest was created with oob_score=True.
rrf.oob_score_

In [None]:
# The note below is Korean for "on hold".
"""
보류
"""

class PermutationImportance:
    """Permutation feature importance for a fitted recurrent-event model.

    A feature's importance is the drop in the model's recurrent C-index when
    that feature's values are shuffled *between subjects* (all rows of one id
    receive the values of another id), repeated ``n_repeats`` times.

    The model must provide ``predict_mean_function(X, ids)`` and
    ``_estimate_recurrent_concordance_index(predictions, time_stop, event,
    ids, total_events)``.
    """

    def __init__(self, model, n_repeats=30, random_state=None):
        """
        Constructor of the class
        'model': The trained model for which we want to compute feature importances
        'n_repeats': Number of times to repeat the permutation for each feature to get a reliable estimate.
        'random_state': Seed for reproducibility
        """
        self.model = model
        self.n_repeats = n_repeats
        self.random_state = random_state

    @staticmethod
    def _row_scores(predictions, ids):
        """Collapse the model's mean-function prediction into one score per row.

        Two prediction layouts are accepted:

        * a mapping ``id -> {'mean_function': [...], ...}`` (the layout the
          forest's own ``predict_rate_function`` consumes): each id is scored
          by the average of its mean function and the score is repeated for
          every row carrying that id;
        * a 2-D array-like with one row per sample: each row is scored by its
          average over axis 1.
        """
        if isinstance(predictions, dict):
            score_by_id = {
                uid: np.mean(entry['mean_function'])
                for uid, entry in predictions.items()
            }
            return np.array([score_by_id[id_] for id_ in ids])
        return np.mean(predictions, axis=1)

    @staticmethod
    def _permute_feature_by_id(X, feature, rows_by_id, permuted_ids):
        """Return a copy of ``X`` in which column ``feature`` is shuffled between ids.

        ``rows_by_id`` maps every id to the row indices it owns; its iteration
        order gives the receiving ids and ``permuted_ids`` gives, position by
        position, the donor id for each receiver. ``X`` itself is not modified.
        """
        X_permuted = X.copy()
        for (_, receiver_rows), donor_id in zip(rows_by_id.items(), permuted_ids):
            donor_values = X[rows_by_id[donor_id], feature]
            # Ids generally own different numbers of rows, so a plain
            # assignment would fail to broadcast. np.resize repeats or
            # truncates the donor's values to the receiver's row count; it is
            # a no-op when the counts match and yields the constant value for
            # covariates that do not vary within an id.
            X_permuted[receiver_rows, feature] = np.resize(donor_values, receiver_rows.size)
        return X_permuted

    def _compute_baseline_cindex(self, X, event, time_stop, ids):
        """
        Computes the recurrent C-index of the model on ``X``.

        Despite the name, this is used both for the original data (the
        baseline) and for every permuted copy of it.

        The model's mean function is predicted for ``(X, ids)``, reduced to one
        score per row, and passed to the model's C-index estimator together
        with the number of observed events (``event == 1``) of each row's id.
        """
        # `ids` is passed along, as in every other call to
        # predict_mean_function in this file.
        predictions = self.model.predict_mean_function(X, ids)
        row_scores = self._row_scores(predictions, ids)

        # Total number of events per id, expanded back to one value per row.
        # Only rows with event == 1 are counted; counting all rows would also
        # include censoring rows.
        unique_ids, id_index = np.unique(np.asarray(ids), return_inverse=True)
        is_event = np.asarray(event) == 1
        events_per_id = np.bincount(id_index[is_event], minlength=unique_ids.size)
        total_events_arr = events_per_id[id_index]

        return self._estimate_concordance_index_recurrent(time_stop, event, row_scores, ids, total_events_arr)

    def _estimate_concordance_index_recurrent(self, time_stop, event, predictions, ids, total_events):
        """
        A wrapper around the model's method to estimate the recurrent C-index.
        It only reorders the arguments into the order the model method is called with.
        """
        return self.model._estimate_recurrent_concordance_index(predictions, time_stop, event, ids, total_events)

    def compute_importance(self, X, event, time_stop, ids):
        """
        The main function that computes the feature importances.

        For each feature:
            It permutes (shuffles) the feature's values between ids a number of times (specified by 'n_repeats').
            For each permutation, it calculates the drop in C-index (compared to the baseline C-index) due to the permutation.
            The drop in performance (C-index) due to the permutation gives an indication of the feature's importance.

        It returns an array of shape (n_features, n_repeats): each row corresponds
        to a feature and each column to a permutation repetition.
        """
        baseline_cindex = self._compute_baseline_cindex(X, event, time_stop, ids)

        rng = check_random_state(self.random_state)

        n_features = X.shape[1]
        importances = np.zeros((n_features, self.n_repeats))

        # The rows owned by each id do not depend on the feature or the repeat,
        # so they are located once instead of inside the innermost loop.
        ids_arr = np.asarray(ids)
        unique_ids = np.unique(ids_arr)
        rows_by_id = {uid: np.flatnonzero(ids_arr == uid) for uid in unique_ids}

        for feature in range(n_features):
            for repeat in range(self.n_repeats):
                # Shuffle which id donates its feature values to which id.
                permuted_ids = rng.permutation(unique_ids)
                X_permuted = self._permute_feature_by_id(X, feature, rows_by_id, permuted_ids)

                # Calculate c-index for permuted X
                permuted_cindex = self._compute_baseline_cindex(X_permuted, event, time_stop, ids)

                # The importance is the drop in c-index
                importances[feature, repeat] = baseline_cindex - permuted_cindex

        return importances

    def report_importance(self, X, event, time_stop, ids):
        """
        Computes the feature importances and then calculates the mean and standard deviation of the importances for each feature across all the repeats.
        The mean gives an average measure of the importance of each feature, while the standard deviation provides an estimate of the variability or uncertainty in the importance estimates.
        It returns the mean and standard deviation of the feature importances.
        """
        importances = self.compute_importance(X, event, time_stop, ids)

        # Compute mean and std of importances
        importance_mean = np.mean(importances, axis=1)
        importance_std = np.std(importances, axis=1)

        return importance_mean, importance_std


In [None]:
# Permutation importance of each covariate column of `x` for the fitted forest,
# with 10 shuffles per feature and a fixed seed.
permutation_importance = PermutationImportance(model=rrf, n_repeats=10, random_state=42)
mean_importances, std_importances = permutation_importance.report_importance(
    X=x,
    event=event,
    time_stop=time_stop,
    ids=ids,
)

# Per-feature summary across the repeats.
print("Mean Importances:", mean_importances)
print("STD Importances:", std_importances)