In [1]:
!pip install scikit-survival



In [2]:
import numpy as np
import pandas as pd


path="/Users/jeongsookim/Downloads"
data = pd.read_csv(f"{path}/simuDat.csv")

ids = data['id'].values
time_start = data['start'].values
time_stop = data['stop'].values
event = data['event'].values
x = data[['group','x1','gender']].values

### RisksetCounter

In [3]:
class RisksetCounter:
    def __init__(self, ids, time_start, time_stop, event):
        self.ids = ids
        self.time_start = time_start
        self.time_stop = time_stop
        self.event = event

        self.all_unique_times = np.unique(time_stop)
        self.n_unique_times = len(self.all_unique_times)

        self.n_at_risk = np.zeros(self.n_unique_times, dtype=np.int64)
        self.n_events = np.zeros(self.n_unique_times, dtype=np.int64)
        self.set_data()

        self.state_stack = []

    def set_data(self):
        unique_ids = set(self.ids)
        for t_idx, t in enumerate(self.all_unique_times):
            self.n_at_risk[t_idx] = sum([self.Y_i(id_, t_idx) for id_ in unique_ids])
            self.n_events[t_idx] = sum([self.dN_bar_i(id_, t_idx) for id_ in unique_ids])

    def Y_i(self, id_, t_idx):
        if t_idx >= len(self.all_unique_times):
            return 0
        time_at_t_idx = self.all_unique_times[t_idx]
        indices = (self.ids == id_) & (time_at_t_idx <= self.time_stop)
        return np.any(indices)

    def dN_bar_i(self, id_, t_idx):
        if t_idx >= len(self.all_unique_times):
            return 0
        time_at_t_idx = self.all_unique_times[t_idx]
        indices = (self.ids == id_) & (time_at_t_idx == self.time_stop) & (self.event == 1)
        return np.any(indices)

    def save_state(self):
        self.state_stack.append((self.ids.copy(), self.time_start.copy(), self.time_stop.copy(), self.event.copy(), self.n_at_risk.copy(), self.n_events.copy()))

    def load_state(self):
        if self.state_stack:
            self.ids, self.time_start, self.time_stop, self.event, self.n_at_risk, self.n_events = self.state_stack.pop()

    def update(self, new_ids, new_time_start, new_time_stop, new_event): 
        # Save the current state
        self.save_state()    

        # Compute the intersection of data
        mask = np.isin(self.ids, new_ids)
    
        # Extract data of the intersection
        updated_ids = self.ids[mask]
        updated_time_start = self.time_start[mask]
        updated_time_stop = self.time_stop[mask]
        updated_event = self.event[mask]

        # Update object variables based on the intersection data
        self.ids = updated_ids
        self.time_start = updated_time_start
        self.time_stop = updated_time_stop
        self.event = updated_event

        # Recalculate unique times based on the updated data
        self.all_unique_times = np.unique(np.concatenate([self.time_start, self.time_stop]))
        self.n_unique_times = len(self.all_unique_times)
    
        # Resize the n_at_risk and n_events arrays based on the updated unique times
        self.n_at_risk = np.zeros(self.n_unique_times, dtype=np.int64)
        self.n_events = np.zeros(self.n_unique_times, dtype=np.int64)

        # Update the n_at_risk and n_events arrays
        unique_ids = set(self.ids)  # Extract unique IDs to avoid redundant calculations
        for t_idx, t in enumerate(self.all_unique_times):
            self.n_at_risk[t_idx] = sum([self.Y_i(id_, t_idx) for id_ in unique_ids])
            self.n_events[t_idx] = sum([self.dN_bar_i(id_, t_idx) for id_ in unique_ids])

    def reset(self):
        self.load_state()

    def copy(self):
        return RisksetCounter(self.ids.copy(), self.time_start.copy(), self.time_stop.copy(), self.event.copy())

    def __reduce__(self):
        return (self.__class__, (self.ids, self.time_start, self.time_stop, self.event))

# Let's check if the updated RisksetCounter works
riskset_test = RisksetCounter(ids, time_start, time_stop, event)
riskset_test.n_at_risk, riskset_test.n_events, len(riskset_test.n_at_risk), len(riskset_test.n_events)


(array([100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,
        100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,
        100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,
        100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,
        100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,
        100, 100, 100, 100, 100, 100, 100, 100, 100,  98,  98,  98,  98,
         98,  98,  98,  98,  98,  98,  98,  98,  98,  98,  98,  98,  98,
         98,  98,  98,  98,  98,  98,  98,  98,  88,  88,  88,  88,  88,
         88,  88,  88,  88,  88,  88,  88,  88,  88,  88,  88,  88,  88,
         88,  88,  88,  88,  88,  88,  88,  68,  68,  68,  68,  68,  68,
         68,  68,  68,  68,  68,  68,  68,  68,  68,  68,  68,  68,  68,
         68,  68,  68,  68,  68,  68,  68,  68]),
 array([2, 4, 3, 2, 2, 1, 1, 1, 1, 3, 1, 6, 5, 2, 3, 1, 2, 2, 2, 1, 4, 3,
        5, 3, 1, 4, 3, 2, 6, 3, 3, 5, 4, 1, 2, 2, 4, 3, 2, 1, 1, 1, 3, 4,

In [4]:
import numpy as np
import pandas as pd


path="/Users/jeongsookim/Downloads"
simuDat= pd.read_csv(f"{path}/simuDat.csv")
simuDat

Unnamed: 0.1,Unnamed: 0,id,start,stop,event,group,x1,gender
0,1,1,0,1,1,0,-1.93,1
1,2,1,1,22,1,0,-1.93,1
2,3,1,22,23,1,0,-1.93,1
3,4,1,23,57,1,0,-1.93,1
4,5,1,57,112,0,0,-1.93,1
...,...,...,...,...,...,...,...,...
495,496,100,119,123,1,1,-0.93,1
496,497,100,123,124,1,1,-0.93,1
497,498,100,124,134,1,1,-0.93,1
498,499,100,134,136,1,1,-0.93,1


In [5]:
riskset_counter = RisksetCounter(ids=simuDat["id"].values, 
                                       time_start=simuDat["start"].values, 
                                       time_stop=simuDat["stop"].values, 
                                       event=simuDat["event"].values)

#### Group==Contr

In [6]:
simuDat_male = simuDat[simuDat["group"] == 0]

In [7]:
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [8]:
riskset_counter.n_at_risk, len(riskset_counter.n_at_risk)

(array([55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55,
        55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55,
        55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55,
        55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55, 55,
        55, 55, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54,
        54, 54, 54, 54, 54, 54, 54, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48,
        48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 39, 39, 39, 39,
        39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39, 39,
        39, 39, 39]),
 139)

In [9]:
riskset_counter.n_events, len(riskset_counter.n_events)

(array([0, 2, 3, 3, 1, 2, 1, 1, 1, 2, 1, 5, 4, 2, 3, 1, 2, 1, 1, 1, 3, 3,
        5, 2, 1, 1, 3, 2, 4, 3, 2, 2, 1, 1, 1, 2, 3, 3, 2, 1, 1, 1, 3, 4,
        2, 3, 3, 2, 2, 3, 2, 3, 3, 1, 3, 2, 1, 1, 2, 3, 2, 2, 2, 1, 1, 1,
        1, 1, 1, 2, 1, 3, 2, 3, 1, 2, 2, 1, 1, 1, 3, 4, 2, 5, 1, 2, 1, 4,
        3, 3, 1, 1, 2, 3, 1, 2, 1, 1, 2, 1, 3, 1, 1, 1, 1, 2, 4, 1, 2, 2,
        3, 2, 4, 3, 2, 2, 1, 1, 2, 3, 2, 1, 1, 1, 1, 2, 4, 2, 1, 2, 3, 1,
        1, 2, 2, 3, 1, 2, 0]),
 139)

In [10]:
np.unique(riskset_counter.ids)

array([ 1,  3,  4,  5,  8,  9, 11, 12, 13, 16, 19, 21, 22, 25, 27, 28, 29,
       30, 31, 32, 33, 34, 35, 36, 40, 43, 44, 46, 47, 48, 49, 51, 52, 53,
       54, 55, 56, 58, 63, 64, 65, 66, 68, 71, 72, 76, 77, 82, 84, 86, 90,
       91, 93, 95, 99])

In [11]:
riskset_counter.reset()

In [12]:
simuDat_male = simuDat[simuDat["group"] == 1]

In [13]:
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [14]:
riskset_counter.n_at_risk, len(riskset_counter.n_at_risk)

(array([45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45,
        45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 45, 44,
        44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 44, 40, 40, 40,
        40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40, 29,
        29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
        29]),
 86)

In [15]:
riskset_counter.n_events, len(riskset_counter.n_events)

(array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 2, 1, 3, 3, 1, 1, 1, 2, 1, 2,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 2, 1, 1, 2,
        1, 1, 2, 3, 2, 1, 1, 2, 2, 1, 1, 3, 1, 3, 3, 3, 2, 2, 1, 2, 1, 1,
        0, 1, 1, 3, 1, 2, 4, 3, 1, 3, 1, 1, 3, 1, 1, 1, 2, 1, 1, 0]),
 86)

In [16]:
np.unique(riskset_counter.ids)

array([  2,   6,   7,  10,  14,  15,  17,  18,  20,  23,  24,  26,  37,
        38,  39,  41,  42,  45,  50,  57,  59,  60,  61,  62,  67,  69,
        70,  73,  74,  75,  78,  79,  80,  81,  83,  85,  87,  88,  89,
        92,  94,  96,  97,  98, 100])

#### Male

In [17]:
riskset_counter.reset()

In [18]:
simuDat_male = simuDat[simuDat["gender"] == 0]

In [19]:
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [20]:
riskset_counter.n_at_risk, len(riskset_counter.n_at_risk)

(array([50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
        50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
        50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
        50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
        50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 46, 46, 46, 46, 46, 46, 46,
        46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 46, 37, 37, 37,
        37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
        37, 37, 37, 37, 37]),
 124)

In [21]:
riskset_counter.n_events, len(riskset_counter.n_events)

(array([0, 1, 1, 2, 1, 1, 1, 2, 3, 2, 1, 1, 1, 1, 2, 3, 1, 2, 1, 3, 3, 1,
        4, 3, 1, 3, 1, 1, 3, 3, 1, 1, 1, 1, 2, 1, 2, 4, 3, 2, 2, 3, 1, 3,
        2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 3, 1, 3, 3, 2, 1, 1, 1, 3,
        3, 3, 1, 2, 2, 1, 1, 4, 3, 2, 1, 2, 3, 4, 1, 1, 1, 3, 2, 2, 1, 2,
        2, 3, 3, 2, 2, 3, 2, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 1, 2, 1, 1, 2,
        6, 1, 1, 3, 2, 3, 1, 2, 1, 2, 2, 2, 3, 0]),
 124)

In [22]:
np.unique(riskset_counter.ids)

array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
       60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75])

In [23]:
np.unique(simuDat_male['id'].values)

array([26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42,
       43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
       60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75])

#### Female

In [24]:
riskset_counter.reset()

In [25]:
simuDat_male = simuDat[simuDat["gender"] == 1]

In [26]:
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [27]:
riskset_counter.n_events, len(riskset_counter.n_events)


(array([0, 1, 4, 2, 2, 1, 1, 1, 3, 3, 1, 2, 1, 1, 1, 2, 1, 2, 4, 1, 1, 1,
        2, 2, 2, 3, 2, 2, 1, 1, 1, 3, 2, 1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 3,
        3, 1, 1, 1, 1, 1, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 1, 4, 1, 1,
        1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 1, 4, 2, 1, 2, 4, 1, 2, 3, 1,
        3, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 0]),
 107)

In [28]:
riskset_counter.n_at_risk, len(riskset_counter.n_at_risk)

(array([50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
        50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
        50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50,
        50, 50, 50, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48,
        48, 48, 48, 48, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42,
        42, 42, 42, 42, 42, 42, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31,
        31, 31, 31, 31, 31]),
 107)

In [29]:
riskset_counter.n_events, len(riskset_counter.n_events)


(array([0, 1, 4, 2, 2, 1, 1, 1, 3, 3, 1, 2, 1, 1, 1, 2, 1, 2, 4, 1, 1, 1,
        2, 2, 2, 3, 2, 2, 1, 1, 1, 3, 2, 1, 2, 1, 1, 2, 2, 1, 1, 1, 1, 3,
        3, 1, 1, 1, 1, 1, 2, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 1, 4, 1, 1,
        1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 2, 1, 1, 4, 2, 1, 2, 4, 1, 2, 3, 1,
        3, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 0]),
 107)

In [30]:
np.unique(riskset_counter.ids)

array([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  76,
        77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,
        90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100])

In [31]:
np.unique(simuDat_male['id'].values)

array([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  76,
        77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,
        90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100])

#### Contr&male

In [32]:
simuDat_sub1 = simuDat[(simuDat["group"] == 0)&(simuDat['gender']==0)]

In [34]:
riskset_counter.reset()

In [35]:
np.unique(riskset_counter.ids)

array([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,
        27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,
        40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
        53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,
        66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,
        79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,
        92,  93,  94,  95,  96,  97,  98,  99, 100])

In [36]:
simuDat_male = simuDat[simuDat["group"] == 0]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [37]:
simuDat_male = simuDat[simuDat['gender'] == 0]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [38]:
np.unique(riskset_counter.ids)

array([27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 40, 43, 44, 46, 47, 48, 49,
       51, 52, 53, 54, 55, 56, 58, 63, 64, 65, 66, 68, 71, 72])

In [39]:
np.unique(simuDat_sub1['id'].values)

array([27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 40, 43, 44, 46, 47, 48, 49,
       51, 52, 53, 54, 55, 56, 58, 63, 64, 65, 66, 68, 71, 72])

#### Contr & female

In [40]:
simuDat_sub1 = simuDat[(simuDat["group"] == 0)&(simuDat['gender']==1)]

In [42]:
riskset_counter.reset()

In [43]:
np.unique(riskset_counter.ids)

array([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,
        27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,
        40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
        53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,
        66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,
        79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,
        92,  93,  94,  95,  96,  97,  98,  99, 100])

In [44]:
simuDat_male = simuDat[simuDat["group"] == 0]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [45]:
simuDat_male = simuDat[simuDat["gender"] == 1]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [46]:
np.unique(riskset_counter.ids)

array([ 1,  3,  4,  5,  8,  9, 11, 12, 13, 16, 19, 21, 22, 25, 76, 77, 82,
       84, 86, 90, 91, 93, 95, 99])

In [47]:
np.unique(simuDat_sub1['id'].values)

array([ 1,  3,  4,  5,  8,  9, 11, 12, 13, 16, 19, 21, 22, 25, 76, 77, 82,
       84, 86, 90, 91, 93, 95, 99])

#### Treat&male

In [48]:
simuDat_sub1 = simuDat[(simuDat["group"] == 1)&(simuDat['gender']==0)]

In [50]:
riskset_counter.reset()

In [51]:
simuDat_male = simuDat[simuDat['group']==1]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [52]:
simuDat_male = simuDat[simuDat['gender']==0]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [53]:
np.unique(riskset_counter.ids)

array([26, 37, 38, 39, 41, 42, 45, 50, 57, 59, 60, 61, 62, 67, 69, 70, 73,
       74, 75])

In [54]:
np.unique(simuDat_sub1['id'].values)

array([26, 37, 38, 39, 41, 42, 45, 50, 57, 59, 60, 61, 62, 67, 69, 70, 73,
       74, 75])

#### Treat&female

In [55]:
simuDat_sub1 = simuDat[(simuDat["group"] == 1)&(simuDat['gender']==1)]

In [57]:
riskset_counter.reset()

In [58]:
simuDat_male = simuDat[simuDat['group']==1]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [59]:
simuDat_male = simuDat[simuDat['gender']==1]
riskset_counter.update(simuDat_male['id'].values, simuDat_male['start'].values, simuDat_male['stop'].values, simuDat_male['event'].values)

In [60]:
np.unique(riskset_counter.ids)

array([  2,   6,   7,  10,  14,  15,  17,  18,  20,  23,  24,  78,  79,
        80,  81,  83,  85,  87,  88,  89,  92,  94,  96,  97,  98, 100])

In [61]:
np.unique(simuDat_sub1['id'].values)

array([  2,   6,   7,  10,  14,  15,  17,  18,  20,  23,  24,  78,  79,
        80,  81,  83,  85,  87,  88,  89,  92,  94,  96,  97,  98, 100])

In [6]:
def argbinsearch(arr, key_val):
    arr_len = len(arr)
    min_idx = 0
    max_idx = arr_len

    while min_idx < max_idx:
        mid_idx = min_idx + ((max_idx - min_idx) // 2)

        if mid_idx < 0 or mid_idx >= arr_len:
            return -1

        mid_val = arr[mid_idx]
        if mid_val <= key_val:  # Change the condition to <=
            min_idx = mid_idx + 1
        else:
            max_idx = mid_idx

    return min_idx

이 함수는 argbinsearch라는 이름의 함수로, 배열에서 주어진 키 값보다 크거나 같은 첫 번째 원소의 인덱스를 이진 탐색으로 찾아 반환합니다.

자세한 코드 설명을 아래에 제공합니다:

1. 입력:

  * arr: 탐색 대상인 정렬된 배열
  key_val: 찾고자 하는 키 값

2. 초기 변수 설정:

  * arr_len: 배열의 길이를 저장합니다.
  * min_idx: 탐색 범위의 최솟값으로, 처음에는 배열의 시작 인덱스인 0으로 설정됩니다.
  * max_idx: 탐색 범위의 최댓값으로, 처음에는 배열의 길이로 설정됩니다.

3. 이진 탐색:

  * while 루프를 사용하여 min_idx가 max_idx보다 작은 동안 탐색을 반복합니다.
  * mid_idx: 현재 탐색 범위의 중간 인덱스를 계산합니다.
  * mid_val: 중간 인덱스에 해당하는 배열의 원소 값을 가져옵니다.

4. 키 값과 중간 값을 비교합니다:
  * 만약 중간 값이 키 값보다 작거나 같으면, min_idx를 mid_idx + 1로 업데이트합니다. 이렇게 하면 탐색 범위의 왼쪽 부분을 제외하게 됩니다.
  * 그렇지 않으면, max_idx를 mid_idx로 업데이트합니다. 이렇게 하면 탐색 범위의 오른쪽 부분을 제외하게 됩니다.

5. 결과 반환:

  * 루프가 종료되면, min_idx는 키 값보다 크거나 같은 첫 번째 원소의 인덱스를 가리키게 됩니다. 따라서 min_idx를 반환합니다.

이 함수는 정렬된 배열에서 주어진 키 값보다 크거나 같은 첫 번째 원소의 위치를 효율적으로 찾기 위해 사용됩니다. 이진 탐색은 배열의 중간 값을 반복적으로 확인하면서 탐색 범위를 절반씩 줄여나가므로, 큰 배열에서도 빠르게 원하는 값을 찾을 수 있습니다.

### PseudoScoreCriterion

#### V2. Criterion

In [7]:
class PseudoScoreCriterion:
    def __init__(self, n_outputs, n_samples, unique_times, x, ids, time_start, time_stop, event):
        """
        Constructor of the class
        Initialize instance variables using the provided input parameters
        Objects 'riskset_left', 'riskset_right', and 'riskset_total' are initialized using the 'RisksetCounter' class
        """
        self.n_outputs = n_outputs
        self.n_samples = n_samples
        self.unique_times = unique_times
        self.x = x
        self.ids = ids
        self.time_start = time_start
        self.time_stop = time_stop
        self.event = event

        self.unique_ids = set(self.ids)  # Store unique ids for later use
        self.unique_times = unique_times

        self.riskset_left = RisksetCounter(ids, time_start, time_stop, event)
        self.riskset_right = RisksetCounter(ids, time_start, time_stop, event)
        self.riskset_total = RisksetCounter(ids, time_start, time_stop, event)

        self.samples_time_idx = np.searchsorted(unique_times, time_stop)

        self.split_pos = 0
        self.split_time_idx = 0

        self._riskset_counter = RisksetCounter(ids, time_start, time_stop, event)  # 새로 추가

    def init(self, y, sample_weight, n_samples, samples, start, end):
        """
        Initialization function
        Reset the risk set counters ('riskset_left','riskset_right','riskset_total') and updates 'riskset_total' with new data
        """
        self.samples = samples
        self.riskset_left.reset()
        self.riskset_right.reset()
        self.riskset_total.reset()

        time_starts, stop_times, events = y[:, 0], y[:, 1], y[:, 2]
        ids_for_update = [self.ids[idx] for idx in samples[start:end]]
        time_starts_for_update = [time_starts[idx] for idx in samples[start:end]]
        stop_times_for_update = [stop_times[idx] for idx in samples[start:end]]
        events_for_update = [events[idx] for idx in samples[start:end]]

        # Combine unique times from both datasets
        self.unique_times = np.unique(np.concatenate([self.unique_times, stop_times_for_update]))

        self.riskset_total.update(ids_for_update, time_starts_for_update, stop_times_for_update, events_for_update)

    def set_unique_times(self, unique_times):
        """Sets the unique times for the current node."""
        self.unique_times = unique_times

## Group Indicator만으로 나누기...

    # Functions returning the risk set value and event value for the given ID and time index from the respective risk set (left or right)
    def Y_left_value(self, id_, t):
        return self.riskset_left.Y_i(id_, t)
    
    def Y_right_value(self, id_, t):
        return self.riskset_right.Y_i(id_, t)

    def dN_bar_left_value(self, id_, t):
        return self.riskset_left.dN_bar_i(id_, t)

    def dN_bar_right_value(self, id_, t):
        return self.riskset_right.dN_bar_i(id_, t)

    def temporary_update_riskset(self, riskset_counter, ids, time_start, time_stop, event):
        # Combine and find unique stop times from both nodes
        combined_time_stops = np.concatenate([self.riskset_left.time_stop, self.riskset_right.time_stop])
        unique_time_stops = np.unique(combined_time_stops)

        riskset_counter.all_unique_times = unique_time_stops

        # Resize the n_at_risk and n_events arrays based on the updated unique times
        riskset_counter.n_at_risk = np.zeros(len(unique_time_stops), dtype=np.int64)
        riskset_counter.n_events = np.zeros(len(unique_time_stops), dtype=np.int64)

        # Update the n_at_risk and n_events arrays
        unique_ids = set(ids)  # Extract unique IDs to avoid redundant calculations
        for t_idx, t in enumerate(unique_time_stops):
            riskset_counter.n_at_risk[t_idx] = sum([riskset_counter.Y_i(id_, t_idx) for id_ in unique_ids])
            riskset_counter.n_events[t_idx] = sum([riskset_counter.dN_bar_i(id_, t_idx) for id_ in unique_ids])

    def calculate_numerator(self):
        # Temporary update riskset
        self.temporary_update_riskset(self.riskset_left, self.riskset_left.ids, self.riskset_left.time_start, self.riskset_left.time_stop, self.riskset_left.event)
        self.temporary_update_riskset(self.riskset_right, self.riskset_right.ids, self.riskset_right.time_start, self.riskset_right.time_stop, self.riskset_right.event)
    
        w = (self.riskset_left.n_at_risk * self.riskset_right.n_at_risk) / (self.riskset_left.n_at_risk + self.riskset_right.n_at_risk)
        term = (self.riskset_left.n_events / self.riskset_left.n_at_risk) - (self.riskset_right.n_events / self.riskset_right.n_at_risk)
    
        return np.sum(w * term)

    def calculate_variance_estimate(self):
        """
        Update the variance estimate to be compatible with the provided function.
        """
    
        def var_comp(riskset, id_, uniTimeVec, w_const, max_w_const):
            """
            Compute the variance component for each observation, 
            similar to the var_comp function in the mcfDiff.test R code.
            """
            y_i_tj = np.array([riskset.Y_i(id_, t_idx) for t_idx in range(len(uniTimeVec))])
            yVec = riskset.n_at_risk
            n_i_tj = np.array([riskset.dN_bar_i(id_, t_idx) for t_idx in range(len(uniTimeVec))])
            dLambda = riskset.n_events / (riskset.n_at_risk + 1e-7)  # Avoid division by zero

            res_ij = np.where(yVec > 0, y_i_tj / yVec * (n_i_tj - dLambda), 0)

            max_res_ij = np.max(np.abs(res_ij))
    
            if max_res_ij > 0:
                re_res_ij = res_ij / max_res_ij
                reFactor = np.exp(np.log(max_res_ij) + np.log(max_w_const))
            else:
                re_res_ij = 0
                reFactor = 1
    
            res_const = (w_const / max_w_const) * re_res_ij

            return (np.sum(res_const) * reFactor) ** 2

        # Temporary update riskset
        self.temporary_update_riskset(self.riskset_left, self.riskset_left.ids, self.riskset_left.time_start, self.riskset_left.time_stop, self.riskset_left.event)
        self.temporary_update_riskset(self.riskset_right, self.riskset_right.ids, self.riskset_right.time_start, self.riskset_right.time_stop, self.riskset_right.event)

        # Extract required variables
        uniTimeVec = self.riskset_total.all_unique_times
        w_const = (self.riskset_left.n_at_risk * self.riskset_right.n_at_risk) / (self.riskset_left.n_at_risk + self.riskset_right.n_at_risk)
        max_w_const = np.max(w_const)

        # Calculate variance components for each ID in the left and right nodes
        varList1 = [var_comp(self.riskset_left, id_, uniTimeVec, w_const, max_w_const) 
                    for id_ in np.unique(self.riskset_left.ids)]

        varList2 = [var_comp(self.riskset_right, id_, uniTimeVec, w_const, max_w_const) 
                    for id_ in np.unique(self.riskset_right.ids)]
    
        # Sum the variance components
        varU_1 = np.sum(varList1)
        varU_2 = np.sum(varList2)
    
        return varU_1 + varU_2

    
    def calculate_denominator(self):
        return self.calculate_variance_estimate()

    def proxy_impurity_improvement(self):
        if len(self.riskset_left.n_at_risk) == 0 or len(self.riskset_right.n_at_risk) == 0:
            return -np.inf

        numer = self.calculate_numerator() ** 2
        denom = self.calculate_denominator()

        return numer / (denom + 1e-7)
    
    def update_riskset(self, ids_subset):
        # Update the riskset based on the subset of IDs at the current node
        unique_ids_subset = np.unique(ids_subset)
        self.riskset_counter.update(unique_ids_subset, self.time_start, self.time_stop, self.event)

    def node_value(self):
        """
        Returns the Nelson-Aalen estimator of the mean function μ(t) for the entities in the current node.
        """
        return self.node_value_from_riskset(self.riskset_total)

    def node_value_from_riskset(self, riskset_counter):
        """
        Returns the Nelson-Aalen estimator of the mean function μ(t) for the entities based on provided riskset_counter.
        """
        mu_hat_values = []
    
        # Initialize the cumulative sum of the Nelson-Aalen estimator
        cumsum_Nelson_Aalen = 0
    
        for t_idx, t in enumerate(self.unique_times):
            # Use n_at_risk and n_events from the riskset_counter
            n_at_risk_t = riskset_counter.n_at_risk[t_idx] if t_idx < len(riskset_counter.n_at_risk) else 0
            n_events_t = riskset_counter.n_events[t_idx] if t_idx < len(riskset_counter.n_events) else 0
            
            cumsum_Nelson_Aalen += n_events_t / (n_at_risk_t + 1e-7)  # Avoiding division by zero
            mu_hat_values.append(cumsum_Nelson_Aalen)
        
        return mu_hat_values

    # RisksetCounter의 상태를 저장하고 복원하기 위한 메서드를 추가합니다.
    def save_riskset_state(self):
        self._riskset_counter.save_state()

    def reset_riskset_state(self):
        self._riskset_counter.reset()

    def reset(self):
        """
        Functions to reset all risk set counters
        """
        self.riskset_total.reset()
        self.riskset_left.reset()
        self.riskset_right.reset()

    def copy(self):
        """
        Creates and returns a copy of the current object.
        """
        new_criterion = PseudoScoreCriterion(self.n_outputs, self.n_samples, self.unique_times,
                                             self.x, self.ids, self.time_start, self.time_stop, 
                                             self.event)
        new_criterion.riskset_left = self.riskset_left.copy()
        new_criterion.riskset_right = self.riskset_right.copy()
        new_criterion.riskset_total = self.riskset_total.copy()
        new_criterion.samples_time_idx = self.samples_time_idx.copy()
        if hasattr(self, 'samples'):
            new_criterion.samples = self.samples.copy()

        return new_criterion

def update_with_group_indicator(self, feature_index, group_indicator):
    """
    Update the criterion based on a specified feature and group indicator. 
    This will split the data into left and right nodes based on the provided feature and group indicator.
    """
    # Reset the riskset counters for the left and right nodes
    self.riskset_left.reset()
    self.riskset_right.reset()

    # Determine the split by the feature and group indicator
    left_mask = self.x[:, feature_index] <= group_indicator  # Changed to <= for continuous features
    right_mask = ~left_mask

    # Create empty lists to store the ids, start times, stop times, and events for both left and right splits
    ids_left, start_left, stop_left, event_left = [], [], [], []
    ids_right, start_right, stop_right, event_right = [], [], [], []

    # For each unique ID, decide whether to assign it to the left or right node based on the mask
    for id_ in self.unique_ids:
        id_indices = np.where(self.ids == id_)[0]  # Get all indices for this ID
        if left_mask[id_indices[0]]:
            ids_left.extend([self.ids[i] for i in id_indices])
            start_left.extend([self.time_start[i] for i in id_indices])
            stop_left.extend([self.time_stop[i] for i in id_indices])
            event_left.extend([self.event[i] for i in id_indices])
        else:
            ids_right.extend([self.ids[i] for i in id_indices])
            start_right.extend([self.time_start[i] for i in id_indices])
            stop_right.extend([self.time_stop[i] for i in id_indices])
            event_right.extend([self.event[i] for i in id_indices])

    # Set the all_unique_times for the risk sets of left and right nodes to the current node's unique times
    self.riskset_left.all_unique_times = self.unique_times
    self.riskset_right.all_unique_times = self.unique_times

    # Also, adjust the lengths of n_at_risk and n_events in both riskset_left and riskset_right to match unique_times
    self.riskset_left.n_at_risk = np.zeros(len(self.unique_times), dtype=np.int64)
    self.riskset_left.n_events = np.zeros(len(self.unique_times), dtype=np.int64)
    self.riskset_right.n_at_risk = np.zeros(len(self.unique_times), dtype=np.int64)
    self.riskset_right.n_events = np.zeros(len(self.unique_times), dtype=np.int64)

    # Update the risk sets for the left and right nodes
    self.riskset_left.update(ids_left, start_left, stop_left, event_left)
    self.riskset_right.update(ids_right, start_right, stop_right, event_right)

# 이 함수를 PseudoScoreCriterion 클래스에 추가합니다.
setattr(PseudoScoreCriterion, 'update', update_with_group_indicator)

# 추가로, left node와 right node의 데이터를 반환하는 메소드를 추가합니다.
def get_left_node_data(self):
    return self.riskset_left.ids, self.riskset_left.n_at_risk, self.riskset_left.n_events

def get_right_node_data(self):
    return self.riskset_right.ids, self.riskset_right.n_at_risk, self.riskset_right.n_events

setattr(PseudoScoreCriterion, 'get_left_node_data', get_left_node_data)
setattr(PseudoScoreCriterion, 'get_right_node_data', get_right_node_data)

def calculate_node_value_updated(self, side="left"):
    """
    Calculate the node value based on the updated RisksetCounter using get_left_node_data and get_right_node_data.
    
    Parameters:
        - side (str): Either "left" or "right" to determine which riskset to use for calculation.
    """
    if side == "left":
        ids, n_at_risk, n_events = self.get_left_node_data()
    elif side == "right":
        ids, n_at_risk, n_events = self.get_right_node_data()
    else:
        raise ValueError("Invalid side value. Expected 'left' or 'right'.")
    
    mask = np.isin(self.ids, ids)
    
    time_start_filtered = self.time_start[mask]
    time_stop_filtered = self.time_stop[mask]
    event_filtered = self.event[mask]
    
    riskset_temp = RisksetCounter(ids, time_start_filtered, time_stop_filtered, event_filtered)
    riskset_temp.n_at_risk = n_at_risk
    riskset_temp.n_events = n_events

    return self.node_value_from_riskset(riskset_temp)

# PseudoScoreCriterion 클래스에 위에서 정의한 함수를 추가합니다.
setattr(PseudoScoreCriterion, 'calculate_node_value', calculate_node_value_updated)



PseudoScoreCriterion


__main__.PseudoScoreCriterion

In [8]:
riskset_counter.reset()

In [9]:
n_samples = len(ids)

# Create an instance of the RisksetCounter and PseudoScoreCriterion classes
riskset = RisksetCounter(ids, time_start, time_stop, event)
criterion = PseudoScoreCriterion(n_outputs=1, n_samples=n_samples, unique_times=np.unique(time_stop), x=x, ids=ids, time_start=time_start, time_stop=time_stop, event=event)

criterion

<__main__.PseudoScoreCriterion at 0x103671b10>

In [10]:
feature_index = 0
group_indicator = 0

criterion.update(feature_index, group_indicator)

In [11]:
criterion.get_left_node_data()

(array([ 1,  1,  1,  1,  1,  3,  3,  4,  4,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  8,  8,  9,  9,  9,  9,  9,  9,
         9,  9,  9,  9, 11, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 16, 16,
        16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
        16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 19, 21, 21, 22,
        22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 25, 25, 27,
        27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 28, 28,
        28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29,
        29, 30, 31, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33,
        33, 33, 34, 34, 34, 35, 35, 35, 35, 35, 35, 35, 35, 35, 35, 36, 36,
        36, 36, 36, 36, 36, 36, 40, 43, 43, 43, 43, 44, 44, 44, 44, 44, 44,
        44, 44, 44, 44, 46, 46, 46, 46, 46, 46, 47, 47, 47, 47, 47, 47, 47,
        47, 47, 47, 47, 48, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49, 49,
        49, 

In [12]:
criterion.get_right_node_data()

(array([  2,   6,   6,   6,   7,   7,   7,  10,  14,  14,  14,  14,  15,
         17,  18,  18,  20,  23,  24,  24,  24,  26,  26,  26,  26,  26,
         26,  26,  26,  26,  26,  26,  26,  26,  26,  26,  26,  26,  26,
         26,  26,  26,  37,  38,  38,  38,  38,  38,  38,  38,  38,  38,
         38,  39,  39,  39,  41,  41,  41,  41,  41,  41,  41,  41,  41,
         41,  41,  41,  42,  42,  45,  45,  45,  45,  50,  50,  50,  50,
         57,  57,  59,  59,  59,  59,  60,  60,  60,  60,  61,  61,  61,
         62,  67,  67,  67,  67,  67,  67,  67,  67,  67,  67,  69,  70,
         70,  70,  73,  73,  74,  74,  75,  78,  79,  80,  81,  81,  81,
         83,  83,  83,  83,  83,  83,  83,  83,  83,  83,  83,  83,  83,
         83,  83,  83,  83,  85,  87,  87,  87,  87,  87,  87,  88,  88,
         88,  88,  88,  88,  89,  92,  94,  94,  96,  96,  96,  96,  96,
         96,  97,  98,  98,  98, 100, 100, 100, 100, 100, 100, 100, 100,
        100, 100]),
 array([45, 45, 45, 45, 45, 45,

In [13]:
criterion.calculate_variance_estimate()

670.1013958954195

In [14]:
criterion.calculate_numerator()

52.58547418967586

In [15]:
criterion.calculate_denominator()

670.1013958954195

In [16]:
criterion.proxy_impurity_improvement()

4.12658757656426

In [17]:
criterion.calculate_node_value('left')

[0.03636363629752066,
 0.09090909074380166,
 0.14545454519008266,
 0.163636363338843,
 0.19999999963636364,
 0.21818181778512397,
 0.2363636359338843,
 0.2545454540826446,
 0.2545454540826446,
 0.29090909038016527,
 0.3090909085289256,
 0.3999999992727273,
 0.4727272718677686,
 0.5090909081652892,
 0.5636363626115702,
 0.5818181807603305,
 0.6181818170578512,
 0.6363636352066115,
 0.6545454533553718,
 0.6727272715041321,
 0.7272727259504131,
 0.7818181803966942,
 0.8727272711404959,
 0.9090909074380166,
 0.9272727255867769,
 0.9454545437355372,
 0.9999999981818182,
 1.0363636344793388,
 1.1090909070743802,
 1.163636361520661,
 1.1999999978181817,
 1.2363636341157023,
 1.2545454522644626,
 1.2727272704132229,
 1.2909090885619832,
 1.3272727248595038,
 1.3818181793057847,
 1.4363636337520655,
 1.4727272700495861,
 1.4909090881983464,
 1.5090909063471067,
 1.527272724495867,
 1.581818178942148,
 1.6545454515371893,
 1.69090908783471,
 1.7454545422809908,
 1.7999999967272717,
 1.7999999967

In [18]:
criterion.calculate_node_value('right')

[0.0,
 0.022222222172839506,
 0.022222222172839506,
 0.04444444434567901,
 0.04444444434567901,
 0.04444444434567901,
 0.04444444434567901,
 0.04444444434567901,
 0.06666666651851852,
 0.08888888869135803,
 0.08888888869135803,
 0.11111111086419753,
 0.13333333303703704,
 0.13333333303703704,
 0.13333333303703704,
 0.13333333303703704,
 0.13333333303703704,
 0.15555555520987654,
 0.17777777738271605,
 0.17777777738271605,
 0.19999999955555556,
 0.19999999955555556,
 0.19999999955555556,
 0.22222222172839506,
 0.22222222172839506,
 0.2888888882469136,
 0.2888888882469136,
 0.2888888882469136,
 0.3333333325925926,
 0.3333333325925926,
 0.3555555547654321,
 0.4222222212839506,
 0.4888888878024691,
 0.4888888878024691,
 0.5111111099753086,
 0.5111111099753086,
 0.5333333321481482,
 0.5333333321481482,
 0.5333333321481482,
 0.5333333321481482,
 0.5333333321481482,
 0.5333333321481482,
 0.5333333321481482,
 0.5333333321481482,
 0.5333333321481482,
 0.5555555543209877,
 0.5999999986666666,
 0

In [19]:
riskset_counter.reset()

n_samples = len(ids)

# Create an instance of the RisksetCounter and PseudoScoreCriterion classes
riskset = RisksetCounter(ids, time_start, time_stop, event)
criterion = PseudoScoreCriterion(n_outputs=1, n_samples=n_samples, unique_times=np.unique(time_stop), x=x, ids=ids, time_start=time_start, time_stop=time_stop, event=event)

criterion

<__main__.PseudoScoreCriterion at 0x103673190>

In [20]:
feature_index=2
group_indicator = 0

criterion.update(feature_index, group_indicator)

In [21]:
criterion.get_left_node_data()

(array([26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
        26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27,
        27, 27, 27, 28, 28, 28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 29, 29,
        29, 29, 29, 29, 29, 29, 30, 31, 32, 32, 33, 33, 33, 33, 33, 33, 33,
        33, 33, 33, 33, 33, 33, 33, 34, 34, 34, 35, 35, 35, 35, 35, 35, 35,
        35, 35, 35, 36, 36, 36, 36, 36, 36, 36, 36, 37, 38, 38, 38, 38, 38,
        38, 38, 38, 38, 38, 39, 39, 39, 40, 41, 41, 41, 41, 41, 41, 41, 41,
        41, 41, 41, 41, 42, 42, 43, 43, 43, 43, 44, 44, 44, 44, 44, 44, 44,
        44, 44, 44, 45, 45, 45, 45, 46, 46, 46, 46, 46, 46, 47, 47, 47, 47,
        47, 47, 47, 47, 47, 47, 47, 48, 49, 49, 49, 49, 49, 49, 49, 49, 49,
        49, 49, 49, 49, 50, 50, 50, 50, 51, 51, 51, 51, 51, 51, 51, 51, 51,
        52, 53, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54, 54,
        54, 54, 54, 54, 54, 54, 55, 55, 55, 55, 55, 56, 57, 57, 58, 58, 58,
        58, 

In [22]:
criterion.get_right_node_data()

(array([  1,   1,   1,   1,   1,   2,   3,   3,   4,   4,   5,   5,   5,
          5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,   5,
          5,   6,   6,   6,   7,   7,   7,   8,   8,   9,   9,   9,   9,
          9,   9,   9,   9,   9,   9,  10,  11,  12,  12,  12,  12,  12,
         12,  13,  13,  13,  13,  14,  14,  14,  14,  15,  16,  16,  16,
         16,  16,  16,  16,  16,  16,  16,  16,  16,  16,  16,  16,  16,
         16,  16,  16,  16,  16,  16,  16,  16,  16,  16,  16,  16,  16,
         16,  16,  16,  17,  18,  18,  19,  20,  21,  21,  22,  22,  22,
         22,  22,  22,  22,  22,  22,  22,  22,  22,  22,  22,  22,  23,
         24,  24,  24,  25,  25,  76,  77,  78,  79,  80,  81,  81,  81,
         82,  82,  83,  83,  83,  83,  83,  83,  83,  83,  83,  83,  83,
         83,  83,  83,  83,  83,  83,  84,  84,  84,  84,  84,  84,  84,
         84,  84,  84,  84,  84,  85,  86,  86,  86,  86,  86,  86,  86,
         87,  87,  87,  87,  87,  87,  88,  88,  88

In [23]:
criterion.calculate_numerator()

29.369447779111646

In [24]:
criterion.calculate_denominator()

737.9516382850151

In [25]:
criterion.proxy_impurity_improvement()

1.1688631313803468

In [26]:
criterion.node_value()

[0.019999999980000002,
 0.05999999994000001,
 0.08999999991000002,
 0.10999999989000002,
 0.12999999987000002,
 0.13999999986000003,
 0.14999999985000004,
 0.15999999984000005,
 0.16999999983000005,
 0.19999999980000005,
 0.20999999979000006,
 0.26999999973000005,
 0.31999999968000004,
 0.33999999966000005,
 0.3699999996300001,
 0.3799999996200001,
 0.3999999996000001,
 0.4199999995800001,
 0.43999999956000013,
 0.44999999955000014,
 0.4899999995100002,
 0.5199999994800002,
 0.5699999994300002,
 0.5999999994000003,
 0.6099999993900003,
 0.6499999993500003,
 0.6799999993200003,
 0.6999999993000003,
 0.7599999992400004,
 0.7899999992100004,
 0.8199999991800004,
 0.8699999991300005,
 0.9099999990900005,
 0.9199999990800005,
 0.9399999990600005,
 0.9599999990400006,
 0.9999999990000006,
 1.0299999989700006,
 1.0499999989500006,
 1.0599999989400006,
 1.0699999989300006,
 1.0799999989200006,
 1.1099999988900007,
 1.1499999988500007,
 1.1699999988300007,
 1.2099999987900008,
 1.25999999874000

In [27]:
criterion.calculate_node_value('left')

[0.01999999996,
 0.01999999996,
 0.03999999992,
 0.03999999992,
 0.07999999984,
 0.0999999998,
 0.11999999976,
 0.11999999976,
 0.13999999972,
 0.17999999964,
 0.17999999964,
 0.23999999952,
 0.27999999944,
 0.2999999994,
 0.31999999936,
 0.31999999936,
 0.33999999932,
 0.35999999928000004,
 0.35999999928000004,
 0.35999999928000004,
 0.39999999920000007,
 0.45999999908000005,
 0.47999999904000007,
 0.51999999896,
 0.5399999989200001,
 0.5999999988000001,
 0.6599999986800001,
 0.6799999986400002,
 0.7599999984800001,
 0.8199999983600001,
 0.8399999983200002,
 0.8999999982000002,
 0.9199999981600002,
 0.9399999981200002,
 0.9399999981200002,
 0.9399999981200002,
 0.9999999980000003,
 1.0599999978800003,
 1.0799999978400003,
 1.0999999978000004,
 1.1199999977600004,
 1.1399999977200004,
 1.1799999976400004,
 1.1999999976000004,
 1.2399999975200005,
 1.3199999973600005,
 1.3799999972400006,
 1.3799999972400006,
 1.4199999971600006,
 1.4599999970800006,
 1.5199999969600007,
 1.519999996960

In [28]:
criterion.calculate_node_value('right')

[0.01999999996,
 0.0999999998,
 0.13999999972,
 0.17999999964,
 0.17999999964,
 0.17999999964,
 0.17999999964,
 0.19999999959999998,
 0.19999999959999998,
 0.21999999956,
 0.23999999952,
 0.2999999994,
 0.35999999928,
 0.37999999924,
 0.41999999915999997,
 0.43999999912,
 0.45999999908,
 0.47999999904,
 0.51999999896,
 0.5399999989200001,
 0.5799999988400001,
 0.5799999988400001,
 0.65999999868,
 0.67999999864,
 0.67999999864,
 0.6999999986000001,
 0.6999999986000001,
 0.7199999985600001,
 0.7599999984800001,
 0.7599999984800001,
 0.7999999984000001,
 0.8399999983200002,
 0.8999999982000002,
 0.8999999982000002,
 0.9399999981200002,
 0.9799999980400003,
 0.9999999980000003,
 0.9999999980000003,
 1.0199999979600003,
 1.0199999979600003,
 1.0199999979600003,
 1.0199999979600003,
 1.0399999979200003,
 1.0999999978000004,
 1.0999999978000004,
 1.0999999978000004,
 1.1399999977200004,
 1.1599999976800004,
 1.1999999976000004,
 1.2199999975600004,
 1.2399999975200005,
 1.2799999974400005,
 1

### 이전 버전에서 수정한 것!!!

### V2.

In [29]:
import pandas as pd
import numpy as np
from sklearn.utils import check_random_state

class PseudoScoreTreeBuilder:
    """
    Class designed to build a decision tree based on the pseudo-score test statistics criterion,
    typically used in recurrent events data analysis.
    """

    TREE_UNDEFINED = -1  # Placeholder

    def __init__(self, max_depth=None, min_ids_split=2, min_ids_leaf=1,
                 max_features=None, max_thresholds=None, min_impurity_decrease=0,
                 random_state=None):
        self.max_depth = max_depth
        self.min_ids_split = min_ids_split
        self.min_ids_leaf = min_ids_leaf
        self.max_features = max_features
        self.max_thresholds = max_thresholds
        self.min_impurity_decrease = min_impurity_decrease
        self.random_state = check_random_state(random_state)

    def split_indices(self, X_column, threshold, criterion, start, end):
        """Efficiently splits the data based on the given threshold for a specific feature column."""
        left_indices = np.where(X_column <= threshold)[0]
        right_indices = np.where(X_column > threshold)[0]

        # Convert local indices to global indices
        left_indices = np.arange(start, end)[left_indices]
        right_indices = np.arange(start, end)[right_indices]

        return left_indices, right_indices

    def _split(self, X, criterion, start, end):
        best_split = {
            'feature_index': None,
            'threshold': None,
            'improvement': -np.inf
        }
    
        # 가능한 스플릿 후보들을 저장하기 위한 리스트
        potential_splits = []

        n_features = X.shape[1]

        for feature_index in range(n_features):
            unique_thresholds = np.unique(X[start:end, feature_index])
            if len(unique_thresholds) <= 1:
                continue

            if self.max_thresholds and len(unique_thresholds) > self.max_thresholds:
                unique_thresholds = self.random_state.choice(unique_thresholds, self.max_thresholds, replace=False)

            for threshold in unique_thresholds:
                criterion.update(feature_index, threshold)
                improvement = criterion.proxy_impurity_improvement()

                left_indices, right_indices = self.split_indices(X[start:end, feature_index], threshold, criterion, start, end)

                # Ensure that both child nodes will have at least min_ids_leaf samples
                if len(left_indices) < self.min_ids_leaf or len(right_indices) < self.min_ids_leaf:
                    continue

                # self.min_impurity_decrease보다 큰 모든 스플릿 후보들을 저장
                if improvement > self.min_impurity_decrease:
                    potential_splits.append({
                        'feature_index': feature_index,
                        'threshold': threshold,
                        'improvement': improvement
                    })

                if improvement > best_split['improvement']:
                    best_split = {
                        'feature_index': feature_index,
                        'threshold': threshold,
                        'improvement': improvement
                    }

        return best_split, potential_splits



    
    def _build(self, X, y, criterion, depth=0, start=0, end=None):
        if end is None:
            end = X.shape[0]

        ids = y[start:end, 0]
        unique_ids = np.unique(ids)

        riskset_counter = RisksetCounter(ids, y[start:end, 1], y[start:end, 2], y[start:end, 3])
        node_value = criterion.node_value_from_riskset(riskset_counter)
        node_unique_times = riskset_counter.all_unique_times.tolist()
        node_value = node_value[:len(node_unique_times)]

        # Check depth and minimum ids required for split
        if self.max_depth is not None and depth >= self.max_depth:
            return {
                'feature': None,
                'threshold': None,
                'left_child': None,
                'right_child': None,
                'node_value': node_value,
                'unique_times': node_unique_times,
                'ids': unique_ids.tolist()
            }

        if len(unique_ids) < self.min_ids_split:
            return {
                'feature': None,
                'threshold': None,
                'left_child': None,
                'right_child': None,
                'node_value': node_value,
                'unique_times': node_unique_times,
                'ids': unique_ids.tolist()
            }

        best_split, potential_splits = self._split(X, criterion, start, end)

        for split in [best_split] + potential_splits:
            if split['threshold'] is None:
                continue

            left_indices, right_indices = self.split_indices(X[start:end, split['feature_index']], split['threshold'], criterion, start, end)

            # Check if there are enough unique ids in both left and right children after the split
            if len(np.unique(ids[left_indices])) >= self.min_ids_leaf and len(np.unique(ids[right_indices])) >= self.min_ids_leaf:
                best_split = split
                break
        else:  # No valid split found
            return {
                'feature': None,
                'threshold': None,
                'left_child': None,
                'right_child': None,
                'node_value': node_value,
                'unique_times': node_unique_times,
                'ids': unique_ids.tolist()
            }

        left_child = self._build(X[left_indices], y[left_indices], criterion, depth=depth+1)
        right_child = self._build(X[right_indices], y[right_indices], criterion, depth=depth+1)

        return {
            'feature': best_split['feature_index'],
            'threshold': best_split['threshold'],
            'left_child': left_child,
            'right_child': right_child,
            'node_value': node_value,
            'unique_times': node_unique_times,
            'ids': unique_ids.tolist()
        }


    def build(self, X, ids, time_start, time_stop, event):
        """
        The main method to invoke the tree building process.
        Initializes the pseudo-score criterion using the input data and constructs the tree using the _build method.
        """
        n_samples, n_features = X.shape
        y = np.c_[ids, time_start, time_stop, event]

        unique_times = np.unique(np.concatenate([time_start, time_stop]))
        criterion = PseudoScoreCriterion(n_outputs=1, n_samples=n_samples,
                                         unique_times=unique_times, x=X, ids=ids,
                                         time_start=time_start, time_stop=time_stop, event=event)

        tree = self._build(X, y, criterion)
        return tree


In [30]:
data

Unnamed: 0.1,Unnamed: 0,id,start,stop,event,group,x1,gender
0,1,1,0,1,1,0,-1.93,1
1,2,1,1,22,1,0,-1.93,1
2,3,1,22,23,1,0,-1.93,1
3,4,1,23,57,1,0,-1.93,1
4,5,1,57,112,0,0,-1.93,1
...,...,...,...,...,...,...,...,...
495,496,100,119,123,1,1,-0.93,1
496,497,100,123,124,1,1,-0.93,1
497,498,100,124,134,1,1,-0.93,1
498,499,100,134,136,1,1,-0.93,1


In [31]:
ids = data['id'].values
time_start = data['start'].values
time_stop = data['stop'].values
event = data['event'].values
x = data[['group','x1','gender']].values

In [32]:
x

array([[ 0.  , -1.93,  1.  ],
       [ 0.  , -1.93,  1.  ],
       [ 0.  , -1.93,  1.  ],
       ...,
       [ 1.  , -0.93,  1.  ],
       [ 1.  , -0.93,  1.  ],
       [ 1.  , -0.93,  1.  ]])

In [33]:
# Initialize and build the tree using PseudoScoreTreeBuilder
tree_builder = PseudoScoreTreeBuilder(max_depth=3, random_state=1190)

In [34]:
tree_df = tree_builder.build(x, ids, time_start, time_stop, event)

# Display the tree dataframe
tree_df

  term = (self.riskset_left.n_events / self.riskset_left.n_at_risk) - (self.riskset_right.n_events / self.riskset_right.n_at_risk)
  res_ij = np.where(yVec > 0, y_i_tj / yVec * (n_i_tj - dLambda), 0)


{'feature': 1,
 'threshold': 1.75,
 'left_child': {'feature': 1,
  'threshold': 0.18,
  'left_child': {'feature': 1,
   'threshold': 0.17,
   'left_child': {'feature': None,
    'threshold': None,
    'left_child': None,
    'right_child': None,
    'node_value': [0.01818181814876033,
     0.05454545444628099,
     0.07272727259504132,
     0.09090909074380164,
     0.10909090889256197,
     0.1272727270413223,
     0.14545454519008263,
     0.16363636333884296,
     0.1999999996363636,
     0.21818181778512394,
     0.23636363593388426,
     0.2545454540826446,
     0.27272727223140497,
     0.3090909085289256,
     0.3454545448264463,
     0.3818181811239669,
     0.4181818174214876,
     0.43636363557024793,
     0.4727272718677686,
     0.5272727263140495,
     0.5454545444628098,
     0.5636363626115701,
     0.5818181807603304,
     0.5999999989090907,
     0.618181817057851,
     0.6363636352066113,
     0.6545454533553716,
     0.6909090896528923,
     0.7090909078016526,
     

### RecurrentTree V2.

In [30]:
from sklearn.base import BaseEstimator

class RecurrentTree(BaseEstimator):
    def __init__(self, max_depth=None, min_ids_split=2, min_ids_leaf=1, 
                 max_features=None, max_thresholds=None, min_impurity_decrease=0,
                 random_state=None):
        """
        Constructor of the class
        Initializes the tree's hyperparameters and settings
        """
        self.max_depth = max_depth
        self.min_ids_split = min_ids_split
        self.min_ids_leaf = min_ids_leaf
        self.max_features = max_features
        self.max_thresholds = max_thresholds
        self.min_impurity_decrease = min_impurity_decrease
        self.random_state = random_state
        self.tree_ = None

    def fit(self, X, ids, time_start, time_stop, event):
        """
        Trains the recurrent tree using the input data
        """
        X = np.array(X)
        ids = np.array(ids)
        time_start = np.array(time_start)
        time_stop = np.array(time_stop)
        event = np.array(event)

        # Use the PseudoScoreTreeBuilder to build the tree
        builder = PseudoScoreTreeBuilder(
            max_depth=self.max_depth,
            min_ids_split=self.min_ids_split,
            min_ids_leaf=self.min_ids_leaf,
            max_features=self.max_features,
            max_thresholds=self.max_thresholds,
            min_impurity_decrease=self.min_impurity_decrease,
            random_state=self.random_state
        )
        self.tree_ = builder.build(X, ids, time_start, time_stop, event)
        return self

    def get_tree(self):
        """Return the tree as a dictionary."""
        return self.tree_

    def traverse_tree_for_id(self, X_id_samples, node):
        """
        Traverse the tree for a specific ID based on its samples.

        Args:
        - X_id_samples (list of arrays): The samples corresponding to a specific ID.
        - node (dict): The current node being evaluated in the tree.

        Returns:
        - node (dict): The terminal node for the specific ID.
        """
        if node["feature"] is None:  # Terminal node
            return node

        # Traverse the tree for each sample and collect the terminal nodes
        terminal_nodes = []
        for sample in X_id_samples:
            if sample[node["feature"]] <= node["threshold"]:
                terminal_nodes.append(self.traverse_tree_for_id([sample], node["left_child"]))
            else:
                terminal_nodes.append(self.traverse_tree_for_id([sample], node["right_child"]))

        # Check if all samples lead to the same terminal node
        first_terminal = terminal_nodes[0]
        if all(node == first_terminal for node in terminal_nodes):
            return first_terminal

        # If samples lead to different terminal nodes, it's ambiguous. For simplicity, return the first terminal node.
        # In a real-world scenario, this might need more sophisticated handling.
        return first_terminal

    def predict_mean_function(self, X, ids):
        """
        Predict the node_value of the terminal node for given samples.
        """

        # Ensure X is a list of samples
        X = np.array(X)

        mean_function_predictions = {}

        for sample_id in ids:
            samples_for_id = [X[i] for i, uid in enumerate(ids) if uid == sample_id]
            terminal_node_for_id = self.traverse_tree_for_id(samples_for_id, self.tree_)

            mean_function_predictions[sample_id] = terminal_node_for_id["node_value"]

        return mean_function_predictions

    def predict_rate_function(self, X, ids):
        """
        Predict the rate function as the difference between unique time points for the terminal node.
        """
        
        if np.isscalar(ids):
            ids = [ids]
            
        mean_function_predictions = self.predict_mean_function(X, ids)
    
        rate_function_predictions = {}
        for sample_id in ids:
            mean_function_values = mean_function_predictions[sample_id]

            # Calculate rate function as the difference between consecutive mean function values
            rate_function = np.diff(mean_function_values, prepend=mean_function_values[0])

            rate_function_predictions[sample_id] = rate_function

        return rate_function_predictions
    
    def _map_terminal_nodes(self, node, current_id=[0]):
        """
        Recursively traverse the tree and assign unique integers to each terminal node.
        """
        if node["feature"] is None:  # Terminal node
            if "id" not in node:
                node["id"] = current_id[0]
                current_id[0] += 1
            return

        self._map_terminal_nodes(node["left_child"], current_id)
        self._map_terminal_nodes(node["right_child"], current_id)

    def apply(self, X, ids=None):
        """Return the index of the leaf that each unique ID is predicted as."""
        X = np.array(X, dtype=np.float32)
        if ids is None:
            ids = np.array([i for i in range(X.shape[0])])
        else:
            ids = np.array(ids)

        terminal_nodes = {}
        self._map_terminal_nodes(self.tree_)  # Reset the mapping

        for sample_id in np.unique(ids):
            samples_for_id = [X[i] for i, uid in enumerate(ids) if uid == sample_id]
            terminal_node_for_id = self.traverse_tree_for_id(samples_for_id, self.tree_)
            terminal_nodes[sample_id] = terminal_node_for_id["id"]

        return terminal_nodes


In [31]:
x = data[['group','x1','gender']].values

In [32]:
# 2. RecurrentTree 학습 및 예측
tree_model = RecurrentTree(max_depth=3, min_ids_leaf=5, random_state=1190)
tree_model.fit(x, ids, time_start, time_stop, event)

  term = (self.riskset_left.n_events / self.riskset_left.n_at_risk) - (self.riskset_right.n_events / self.riskset_right.n_at_risk)
  res_ij = np.where(yVec > 0, y_i_tj / yVec * (n_i_tj - dLambda), 0)


In [33]:
%pip install graphviz

Note: you may need to restart the kernel to use updated packages.


In [34]:
tree=tree_model.get_tree()

In [35]:
predict_mean_function=tree_model.predict_mean_function(x,ids=ids)
print("ID 1 predicted mean function:",predict_mean_function[1])
print("ID 2 predicted mean function:",predict_mean_function[2])
print("ID 26 predicted mean function:",predict_mean_function[26])
print("ID 27 predicted mean function:",predict_mean_function[27])

ID 1 predicted mean function: [0.03999999984, 0.07999999968, 0.11999999952000001, 0.15999999936, 0.1999999992, 0.23999999904, 0.27999999888, 0.31999999872, 0.35999999856000003, 0.39999999840000006, 0.47999999808000005, 0.5599999977600001, 0.5999999976000001, 0.6399999974400001, 0.6799999972800002, 0.7199999971200002, 0.7599999969600002, 0.7999999968000002, 0.8399999966400002, 0.8799999964800003, 0.9199999963200003, 0.9599999961600003, 0.9999999960000003, 1.0399999958400004, 1.0799999956800004, 1.1199999955200004, 1.1999999952000004, 1.2799999948800005, 1.3199999947200005, 1.3599999945600005, 1.4099999943100006, 1.4599999940600006, 1.5099999938100006, 1.5599999935600006, 1.7099999928100007, 1.7599999925600007, 1.8099999923100007, 1.8599999920600008, 1.9099999918100008, 1.9599999915600008, 2.009999991310001, 2.081428562228368, 2.1528571331467354, 2.2242857040651027, 2.29571427498347, 2.3671428459018373, 2.4385714168202046, 2.509999987738572, 2.509999987738572]
ID 2 predicted mean functio

In [36]:
tree_model.apply(x,ids=ids)

{1: 0,
 2: 1,
 3: 4,
 4: 1,
 5: 4,
 6: 6,
 7: 1,
 8: 3,
 9: 4,
 10: 6,
 11: 0,
 12: 0,
 13: 0,
 14: 0,
 15: 6,
 16: 4,
 17: 6,
 18: 1,
 19: 0,
 20: 7,
 21: 5,
 22: 5,
 23: 7,
 24: 1,
 25: 5,
 26: 6,
 27: 5,
 28: 5,
 29: 5,
 30: 1,
 31: 0,
 32: 0,
 33: 5,
 34: 0,
 35: 5,
 36: 2,
 37: 3,
 38: 1,
 39: 7,
 40: 0,
 41: 7,
 42: 6,
 43: 5,
 44: 5,
 45: 1,
 46: 0,
 47: 5,
 48: 1,
 49: 1,
 50: 1,
 51: 0,
 52: 1,
 53: 0,
 54: 5,
 55: 0,
 56: 0,
 57: 2,
 58: 5,
 59: 0,
 60: 1,
 61: 6,
 62: 0,
 63: 0,
 64: 5,
 65: 1,
 66: 1,
 67: 1,
 68: 2,
 69: 0,
 70: 1,
 71: 0,
 72: 5,
 73: 6,
 74: 3,
 75: 0,
 76: 1,
 77: 1,
 78: 7,
 79: 7,
 80: 7,
 81: 7,
 82: 5,
 83: 7,
 84: 5,
 85: 7,
 86: 1,
 87: 1,
 88: 7,
 89: 3,
 90: 0,
 91: 1,
 92: 0,
 93: 5,
 94: 2,
 95: 5,
 96: 7,
 97: 6,
 98: 2,
 99: 5,
 100: 0}

In [37]:
import graphviz

def visualize_tree_simple(tree):
    """
    Visualize the tree using graphviz.
    This function only displays the structure of the tree and the threshold values for each node.
    """
    graph = graphviz.Digraph()

    def traverse_tree(node, parent_name=None, decision=None):
        nonlocal node_counter
        
        if node is None:
            return
        
        node_name = f"node{node_counter}"
        node_counter += 1
        
        # If the current node is a leaf node
        if node["feature"] is None:
            leaf_info = "\n".join([f"t{idx}: {value:.2f}" for idx, value in enumerate(node['node_value'])])
            graph.node(node_name, label=leaf_info, shape="box")
        else:
            decision_info = f"Feature {node['feature']} <= {node['threshold']:.2f}"
            graph.node(node_name, label=decision_info)
            traverse_tree(node['left_child'], node_name, decision="True")
            traverse_tree(node['right_child'], node_name, decision="False")
        
        # Connect the parent node to the current node
        if parent_name:
            graph.edge(parent_name, node_name, label=decision)

    node_counter = 0
    traverse_tree(tree)

    return graph


In [38]:
dot = visualize_tree_simple(tree)
dot.view()

'Digraph.gv.pdf'

In [39]:
def visualize_tree_with_data(tree):
    """
    Visualize the tree using graphviz.
    This function displays the structure of the tree, the threshold values for each node,
    and the unique IDs at leaf nodes.
    """
    import graphviz

    graph = graphviz.Digraph()

    def traverse_tree(node, parent_name=None, decision=None):
        nonlocal node_counter
        
        if node is None:
            return
        
        node_name = f"node{node_counter}"
        node_counter += 1
        
        # If the current node is a leaf node
        if node["feature"] is None:
            leaf_info = f"Unique IDs: {node['ids']}"
            graph.node(node_name, label=leaf_info, shape="box")
        else:
            decision_info = f"Feature {node['feature']} <= {node['threshold']:.2f}"
            graph.node(node_name, label=decision_info)
            
            traverse_tree(node['left_child'], node_name, decision="True")
            traverse_tree(node['right_child'], node_name, decision="False")
        
        # Connect the parent node to the current node
        if parent_name:
            graph.edge(parent_name, node_name, label=decision)

    node_counter = 0
    traverse_tree(tree)

    return graph




In [40]:
dot = visualize_tree_with_data(tree)
dot.view()

'Digraph.gv.pdf'

In [41]:
import graphviz

def visualize_tree_simple(tree):
    """
    Visualize the tree using graphviz.
    This function only displays the structure of the tree and the threshold values for each node.
    """
    graph = graphviz.Digraph()

    def traverse_tree(node, parent_name=None, decision=None):
        nonlocal node_counter
        
        if node is None:
            return
        
        node_name = f"node{node_counter}"
        node_counter += 1
        
        # If the current node is a leaf node
        if node["feature"] is None:
            graph.node(node_name, label="Leaf Node", shape="box")
        else:
            decision_info = f"Feature {node['feature']} <= {node['threshold']:.2f}"
            graph.node(node_name, label=decision_info)
            traverse_tree(node['left_child'], node_name, decision="True")
            traverse_tree(node['right_child'], node_name, decision="False")
        
        # Connect the parent node to the current node
        if parent_name:
            graph.edge(parent_name, node_name, label=decision)

    node_counter = 0
    traverse_tree(tree)

    return graph
    

In [42]:
dot = visualize_tree_simple(tree)
dot.view()

'Digraph.gv.pdf'

### Mission Clear

In [43]:
x = data[['group','gender']].values
# 2. RecurrentTree 학습 및 예측
tree_model = RecurrentTree(max_depth=2, random_state=1190)
tree_model.fit(x, ids, time_start, time_stop, event)

In [44]:
tree=tree_model.get_tree()
tree

{'feature': 0,
 'threshold': 0,
 'left_child': {'feature': 1,
  'threshold': 0,
  'left_child': {'feature': None,
   'threshold': None,
   'left_child': None,
   'right_child': None,
   'node_value': [0.032258064412070755,
    0.06451612882414151,
    0.12903225764828302,
    0.16129032206035376,
    0.1935483864724245,
    0.22580645088449525,
    0.3225806441207075,
    0.35483870853277827,
    0.387096772944849,
    0.41935483735691975,
    0.4516129017689905,
    0.48387096618106124,
    0.516129030593132,
    0.6129032238293443,
    0.645161288241415,
    0.6774193526534859,
    0.7096774170655566,
    0.7419354814776274,
    0.8387096747138397,
    0.8709677391259105,
    0.935483867950052,
    1.0322580611862642,
    1.0645161255983349,
    1.1290322544224765,
    1.1612903188345471,
    1.1935483832466178,
    1.2580645120707594,
    1.3548387053069717,
    1.3870967697190424,
    1.419354834131113,
    1.4516128985431838,
    1.4838709629552544,
    1.548387091779396,
    1.58

In [45]:
predict_mean_function=tree_model.predict_mean_function(x,ids)
predict_mean_function[64]

[0.032258064412070755,
 0.06451612882414151,
 0.12903225764828302,
 0.16129032206035376,
 0.1935483864724245,
 0.22580645088449525,
 0.3225806441207075,
 0.35483870853277827,
 0.387096772944849,
 0.41935483735691975,
 0.4516129017689905,
 0.48387096618106124,
 0.516129030593132,
 0.6129032238293443,
 0.645161288241415,
 0.6774193526534859,
 0.7096774170655566,
 0.7419354814776274,
 0.8387096747138397,
 0.8709677391259105,
 0.935483867950052,
 1.0322580611862642,
 1.0645161255983349,
 1.1290322544224765,
 1.1612903188345471,
 1.1935483832466178,
 1.2580645120707594,
 1.3548387053069717,
 1.3870967697190424,
 1.419354834131113,
 1.4516128985431838,
 1.4838709629552544,
 1.548387091779396,
 1.5806451561914667,
 1.6451612850156083,
 1.7419354782518206,
 1.8064516070759622,
 1.8709677359001038,
 1.9354838647242454,
 1.999999993548387,
 2.032258057960458,
 2.12903225119667,
 2.1935483800208115,
 2.2258064444328824,
 2.2580645088449534,
 2.2903225732570243,
 2.3548387020811656,
 2.38709676649

In [46]:
predict_mean_function

{1: [0.04166666649305555,
  0.1666666659722222,
  0.2499999989583333,
  0.29166666545138886,
  0.3333333319444444,
  0.3749999984375,
  0.41666666493055554,
  0.49999999791666666,
  0.6249999973958333,
  0.6666666638888888,
  0.749999996875,
  0.7916666633680555,
  0.833333329861111,
  0.8749999963541665,
  0.916666662847222,
  0.9999999958333331,
  1.1666666618055552,
  1.2083333282986108,
  1.2499999947916665,
  1.3333333277777775,
  1.374999994270833,
  1.4166666607638887,
  1.4999999937499997,
  1.5416666602430553,
  1.583333326736111,
  1.6249999932291666,
  1.7499999927083332,
  1.7916666592013888,
  1.8333333256944444,
  1.9166666586805554,
  1.9999999916666664,
  2.041666658159722,
  2.0833333246527777,
  2.1249999911458333,
  2.166666657638889,
  2.249999990625,
  2.374999990104167,
  2.4166666565972226,
  2.458333323090278,
  2.499999989583334,
  2.5416666560763894,
  2.583333322569445,
  2.6249999890625006,
  2.6666666555555563,
  2.753623176916615,
  2.7971014375971444,
  2

In [47]:
tree_model.apply(x,ids)

{1: 9,
 2: 11,
 3: 9,
 4: 9,
 5: 9,
 6: 11,
 7: 11,
 8: 9,
 9: 9,
 10: 11,
 11: 9,
 12: 9,
 13: 9,
 14: 11,
 15: 11,
 16: 9,
 17: 11,
 18: 11,
 19: 9,
 20: 11,
 21: 9,
 22: 9,
 23: 11,
 24: 11,
 25: 9,
 26: 10,
 27: 8,
 28: 8,
 29: 8,
 30: 8,
 31: 8,
 32: 8,
 33: 8,
 34: 8,
 35: 8,
 36: 8,
 37: 10,
 38: 10,
 39: 10,
 40: 8,
 41: 10,
 42: 10,
 43: 8,
 44: 8,
 45: 10,
 46: 8,
 47: 8,
 48: 8,
 49: 8,
 50: 10,
 51: 8,
 52: 8,
 53: 8,
 54: 8,
 55: 8,
 56: 8,
 57: 10,
 58: 8,
 59: 10,
 60: 10,
 61: 10,
 62: 10,
 63: 8,
 64: 8,
 65: 8,
 66: 8,
 67: 10,
 68: 8,
 69: 10,
 70: 10,
 71: 8,
 72: 8,
 73: 10,
 74: 10,
 75: 10,
 76: 9,
 77: 9,
 78: 11,
 79: 11,
 80: 11,
 81: 11,
 82: 9,
 83: 11,
 84: 9,
 85: 11,
 86: 9,
 87: 11,
 88: 11,
 89: 11,
 90: 9,
 91: 9,
 92: 11,
 93: 9,
 94: 11,
 95: 9,
 96: 11,
 97: 11,
 98: 11,
 99: 9,
 100: 11}

In [48]:
import graphviz

def visualize_tree_simple(tree):
    """
    Visualize the tree using graphviz.
    This function only displays the structure of the tree and the threshold values for each node.
    """
    graph = graphviz.Digraph()

    def traverse_tree(node, parent_name=None, decision=None):
        nonlocal node_counter
        
        if node is None:
            return
        
        node_name = f"node{node_counter}"
        node_counter += 1
        
        # If the current node is a leaf node
        if node["feature"] is None:
            graph.node(node_name, label="Leaf Node", shape="box")
        else:
            decision_info = f"Feature {node['feature']} <= {node['threshold']:.2f}"
            graph.node(node_name, label=decision_info)
            traverse_tree(node['left_child'], node_name, decision="True")
            traverse_tree(node['right_child'], node_name, decision="False")
        
        # Connect the parent node to the current node
        if parent_name:
            graph.edge(parent_name, node_name, label=decision)

    node_counter = 0
    traverse_tree(tree)

    return graph
    

In [49]:
dot = visualize_tree_simple(tree)
dot.view()

'Digraph.gv.pdf'

In [50]:
import graphviz

def visualize_tree_simple(tree):
    """
    Visualize the tree using graphviz.
    This function only displays the structure of the tree and the threshold values for each node.
    """
    graph = graphviz.Digraph()

    def traverse_tree(node, parent_name=None, decision=None):
        nonlocal node_counter
        
        if node is None:
            return
        
        node_name = f"node{node_counter}"
        node_counter += 1
        
        # If the current node is a leaf node
        if node["feature"] is None:
            leaf_info = "\n".join([f"t{idx}: {value:.2f}" for idx, value in enumerate(node['node_value'])])
            graph.node(node_name, label=leaf_info, shape="box")
        else:
            decision_info = f"Feature {node['feature']} <= {node['threshold']:.2f}"
            graph.node(node_name, label=decision_info)
            traverse_tree(node['left_child'], node_name, decision="True")
            traverse_tree(node['right_child'], node_name, decision="False")
        
        # Connect the parent node to the current node
        if parent_name:
            graph.edge(parent_name, node_name, label=decision)

    node_counter = 0
    traverse_tree(tree)

    return graph


In [51]:
dot = visualize_tree_simple(tree)
dot.view()

'Digraph.gv.pdf'

In [59]:
def visualize_tree_with_data(tree):
    """
    Visualize the tree using graphviz.
    This function displays the structure of the tree, the threshold values for each node,
    and the unique IDs at leaf nodes.
    """
    import graphviz

    graph = graphviz.Digraph()

    def traverse_tree(node, parent_name=None, decision=None):
        nonlocal node_counter
        
        if node is None:
            return
        
        node_name = f"node{node_counter}"
        node_counter += 1
        
        # If the current node is a leaf node
        if node["feature"] is None:
            leaf_info = f"Unique IDs: {node['ids']}"
            graph.node(node_name, label=leaf_info, shape="box")
        else:
            decision_info = f"Feature {node['feature']} <= {node['threshold']:.2f}"
            graph.node(node_name, label=decision_info)
            
            traverse_tree(node['left_child'], node_name, decision="True")
            traverse_tree(node['right_child'], node_name, decision="False")
        
        # Connect the parent node to the current node
        if parent_name:
            graph.edge(parent_name, node_name, label=decision)

    node_counter = 0
    traverse_tree(tree)

    return graph

In [52]:
dot = visualize_tree_with_data(tree)
dot.view()

'Digraph.gv.pdf'

### RecurrentRandomForest

In [53]:
import numpy as np
from numbers import Integral, Real
from sklearn.utils import check_random_state
from numpy.random import RandomState

def check_random_state(seed):
    """
    Check if seed is a valid random state.
    """
    if seed is None or isinstance(seed, int):
        return np.random.default_rng(seed)
    elif isinstance(seed, np.random.Generator):
        return seed
    else:
        raise ValueError(f"Invalid seed: {seed}")

def _get_n_ids_bootstrap(n_ids, max_ids):
    """
    Modified for recurrent events. Get the number of IDs in a bootstrap sample.
    """
    if max_ids is None:
        return n_ids

    if isinstance(max_ids, Integral):
        if max_ids > n_ids:
            msg = "`max_samples` must be <= n_ids={} but got value {}"
            raise ValueError(msg.format(n_ids, max_ids))
        return max_ids

    if isinstance(max_ids, Real):
        return max(round(n_ids * max_ids), 1)

def _generate_sampled_ids(random_state, unique_ids, max_ids):
    """
    Generate bootstrap sample indices based on unique IDs.
    """
    # Calculate the number of IDs to be sampled using the _get_n_ids_bootstrap function
    n_ids_bootstrap = _get_n_ids_bootstrap(len(unique_ids), max_ids)

    # Create a random instance with the given random_state
    random_instance = check_random_state(random_state)

    # Randomly select n_ids_bootstrap IDs from the unique_ids with replacement
    sampled_ids_indices = random_instance.choice(len(unique_ids), n_ids_bootstrap, replace=True)
    
    # Get the actual IDs using the indices
    sampled_ids = unique_ids[sampled_ids_indices]
    
    return sampled_ids

def _generate_unsampled_ids(unique_ids, sampled_ids):
    """
    Determine unsampled unique IDs from the entire set of IDs.
    """
    # 중복 제거된 sampled_ids
    unique_sampled_ids = np.unique(sampled_ids)
    
    # Find unsampled unique IDs
    unsampled_unique_ids = np.setdiff1d(unique_ids, unique_sampled_ids)
    return unsampled_unique_ids


from warnings import catch_warnings, simplefilter
from sklearn.utils.class_weight import compute_sample_weight

def _parallel_build_trees(
    tree,
    bootstrap,
    X,
    y,
    tree_idx,
    n_trees,
    verbose=0,
    n_ids_bootstrap=None,
    random_state=None
):
    """
    Private function used to fit a single tree in parallel for recurrent events.
    """
    # Extract necessary data from y
    ids = y['id']
    time_start = y['time_start']
    time_stop = y['time_stop']
    event = y['event']

    if verbose > 1:
        print("building tree %d of %d" % (tree_idx + 1, n_trees))

    if bootstrap:
        unique_ids = np.unique(ids)
        n_ids = len(unique_ids)

        # Generate bootstrap samples using IDs
        sampled_ids = _generate_sampled_ids(
            RandomState(random_state), unique_ids, n_ids_bootstrap
        )

        # Expand sampled IDs to all their associated events
        indices = np.where(np.isin(ids, sampled_ids))[0]
        
        tree.fit(X[indices], ids[indices], time_start[indices], time_stop[indices], event[indices])
    else:
        tree.fit(X, ids, time_start, time_stop, event)
    
    return tree





In [27]:
ids = np.array([1, 1, 2, 3, 3, 3, 4, 5, 5])
unique_ids = np.unique(ids)

In [28]:
n_ids_bootstrap = _get_n_ids_bootstrap(len(unique_ids), max_ids=1.0)
print("Number of IDs for bootstrap:", n_ids_bootstrap)

Number of IDs for bootstrap: 5


In [29]:
sampled_ids = _generate_sampled_ids(1190, unique_ids, max_ids=1.0)
sampled_ids

array([4, 3, 3, 4, 2])

In [30]:
unsampled_unique_ids = _generate_unsampled_ids(unique_ids, sampled_ids)

In [31]:
unsampled_unique_ids

array([1, 5])

In [32]:
%pip install dill

Note: you may need to restart the kernel to use updated packages.


In [54]:
def recurrent_concordance_index_score(predictions, X, event, ids):
    """
    Computes the modified C-statistic for recurrent events.
    
    :param predictions: Ensemble predictions for each observation.
    :param X: Feature data.
    :param event: Array indicating event occurrence.
    :param ids: Array of IDs for each observation.
    :return: C-statistic and prediction error.
    """
    unique_ids = np.unique(ids)
    n_unique_ids = len(unique_ids)

    # Calculate total events for each ID
    total_events = {uid: np.sum(event[ids == uid]) for uid in unique_ids}

    id_to_avg_prediction = {}
    for uid in unique_ids:
        uid_indices = np.where(ids == uid)[0]
        uid_predictions = [predictions[i] for i in uid_indices if i in predictions]

        # Check if the uid_predictions list is empty
        if not uid_predictions:
            continue

        # Calculate the average for each element
        max_length = max(map(len, uid_predictions))
        avg_prediction = []
        for i in range(max_length):
            avg_prediction.append(np.mean([pred[i] for pred in uid_predictions if i < len(pred)]))

        id_to_avg_prediction[uid] = avg_prediction

    concordant_pairs = 0
    permissible_pairs = 0

    for i in range(n_unique_ids):
        for j in range(i+1, n_unique_ids):
            uid_i = unique_ids[i]
            uid_j = unique_ids[j]
            
            if uid_i not in id_to_avg_prediction or uid_j not in id_to_avg_prediction:
                continue

            if total_events[uid_i] > total_events[uid_j]:
                permissible_pairs += 1
                if id_to_avg_prediction[uid_i][-1] > id_to_avg_prediction[uid_j][-1]:
                    concordant_pairs += 1

    c_index = concordant_pairs / permissible_pairs if permissible_pairs > 0 else 0
    prediction_error = 1 - c_index
    return c_index, prediction_error


In [55]:
from sklearn.base import BaseEstimator
from sklearn.utils import check_array
from joblib import Parallel, delayed
from numpy.random import RandomState
from sklearn.utils import check_random_state, check_array
from sklearn.exceptions import DataConversionWarning
from scipy.sparse import issparse
MAX_INT = np.iinfo(np.int32).max


from sklearn.utils import check_random_state
import numpy as np

class RecurrentRandomForest(BaseEstimator):
    """
    A Random Forest model designed for recurrent event data.
    """
    def __init__(self, n_estimators=100, max_depth=None, min_ids_split=2,
                 min_ids_leaf=1, bootstrap=True, oob_score=False, n_jobs=None,
                 random_state=None, verbose=0, warm_start=False, max_ids=None,
                 min_impurity_decrease=0.0, max_features=None, max_thresholds=None):
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.min_ids_split = min_ids_split
        self.min_ids_leaf = min_ids_leaf
        self.bootstrap = bootstrap
        self.oob_score = oob_score
        self.n_jobs = n_jobs
        self.verbose = verbose
        self.warm_start = warm_start
        self.max_ids = max_ids
        self.min_impurity_decrease = min_impurity_decrease
        self.max_features = max_features
        self.max_thresholds = max_thresholds
        
        # Initialize the random state for the forest
        self.random_state = check_random_state(random_state)
        
        # Create the estimators using the updated random states
        self.estimators_ = [self._make_estimator() for _ in range(self.n_estimators)]

    def _make_estimator(self):
        """
        Constructs a new instance of the 'RecurrentTree' with the specified hyperparameters.
        Allows for creating each tree with a different 'random_state' for randomness.
        """
        # Generate a new random state for each tree based on the forest's random state
        tree_random_state = self.random_state.randint(np.iinfo(np.int32).max)

        return RecurrentTree(
            max_depth=self.max_depth,
            min_ids_split=self.min_ids_split,
            min_ids_leaf=self.min_ids_leaf,
            random_state=tree_random_state,  # Pass the generated random state for the tree
            min_impurity_decrease=self.min_impurity_decrease,
            max_features=self.max_features,
            max_thresholds=self.max_thresholds
        )

    def fit(self, X, y):
        """
        Build the recurrent random forest.
        """
        X = self._validate_data(X)  # This will validate X and ensure it's an array.
        self.n_features_in_ = X.shape[1]  # Set the number of features attribute.

        # Convert y to the required format
        y_converted = {
            'id': y['id'],
            'time_start': y['time_start'],
            'time_stop': y['time_stop'],
            'event': y['event']
        }

        # Get the number of bootstrap samples
        n_samples_bootstrap = _get_n_ids_bootstrap(len(np.unique(y['id'])), self.max_ids)

        # Train each tree in parallel
        self.estimators_ = Parallel(n_jobs=self.n_jobs)(
            delayed(_parallel_build_trees)(
                tree=tree,
                bootstrap=self.bootstrap,
                X=X,
                y=y_converted,
                tree_idx=i,
                n_trees=self.n_estimators,
                verbose=self.verbose,
                n_ids_bootstrap=n_samples_bootstrap,
                random_state=tree.random_state  # Pass the random state for each tree here
            ) for i, tree in enumerate(self.estimators_)
        )

        # Calculate OOB score and attributes if needed
        if self.oob_score:
            self._set_oob_score_and_attributes(X, y_converted)

        return self
    
    def _set_oob_score_and_attributes(self, X, y):
        """
        Calculates the out-of-bag (OOB) scores using the ensemble's predictions for the training data samples 
        that were not seen during the training of a given tree.
        Also sets the 'oob_prediction_' and 'oob_score_' attributes of the class.
        """
        X = self._validate_data(X)
        ids = y['id']
        event = y['event']

        all_predictions = {}

        for estimator in self.estimators_:
            sampled_ids = _generate_sampled_ids(estimator.random_state, np.unique(ids), self.max_ids)
            unsampled_ids = _generate_unsampled_ids(np.unique(ids), sampled_ids)

            # Refitting the tree using only the unsampled data.
            estimator.fit(X[unsampled_ids, :], y['id'][unsampled_ids], y['time_start'][unsampled_ids], y['time_stop'][unsampled_ids], y['event'][unsampled_ids])

            # Making predictions using the predict_mean_function
            p_estimator_result_all = estimator.predict_mean_function(x, ids)
            p_estimator_result = {uid: p_estimator_result_all[uid] for uid in unsampled_ids if uid in p_estimator_result_all}
        
            for uid, pred in p_estimator_result.items():
                if uid not in all_predictions:
                    all_predictions[uid] = []
                all_predictions[uid].append(pred)

        # Averaging the predictions
        averaged_predictions = {}
        for uid, preds in all_predictions.items():
            averaged_predictions[uid] = self._pad_and_average_predictions(preds, len(self.estimators_))

        self.oob_prediction_ = averaged_predictions

        # Calculating the C-index and Prediction Error using an external function
        self.oob_score_, self.oob_prediction_error_ = recurrent_concordance_index_score(averaged_predictions, X, event, ids)


        """Validate input data('X') to ensure it's in the correct format and meets the necessary conditions for processing."""
    def _validate_data(self, X, accept_sparse=False, ensure_min_samples=1):
        """Validate input data('X') to ensure it's in the correct format and meets the necessary conditions for processing."""
        return check_array(X, accept_sparse=accept_sparse, ensure_min_samples=ensure_min_samples)
    
    def _validate_X_predict(self, X):
        """Validate X whenever one tries to predict."""
        X = check_array(X)
        if X.shape[1] != self.n_features_in_:
            raise ValueError("Number of features of the model must match the input. Model n_features is {} and input n_features is {}."
                             .format(self.n_features_in_, X.shape[1]))
        return X

    def _pad_and_average_predictions(self, all_predictions_for_id, n_trees):
        """
        Pad the predictions to the length of the longest prediction and then average them.
        """
        max_length = max(map(len, all_predictions_for_id))

        # Pad each prediction to the maximum length
        padded_predictions = []
        for prediction in all_predictions_for_id:
            if len(prediction) < max_length:
                pad_length = max_length - len(prediction)
                padded_prediction = np.concatenate([prediction, [prediction[-1]] * pad_length])
            else:
                padded_prediction = prediction
            padded_predictions.append(padded_prediction)

        # Average the padded predictions
        average_prediction = np.mean(padded_predictions, axis=0)
        return average_prediction.tolist()

    def predict_mean_function(self, X, ids):
        X = self._validate_X_predict(X)
    
        # Get predictions from each tree
        all_predictions = [tree.predict_mean_function(X, ids) for tree in self.estimators_]

        # Average the predictions for each unique ID
        averaged_predictions = {}
        for uid in np.unique(ids):
            uid_predictions = [tree_preds[uid] for tree_preds in all_predictions if uid in tree_preds]
            averaged_predictions[uid] = self._pad_and_average_predictions(uid_predictions, len(self.estimators_))
    
        return averaged_predictions

    def predict_rate_function(self, X, ids):
        X = self._validate_X_predict(X)
    
        # Get predictions from each tree
        all_predictions = [tree.predict_rate_function(X, ids) for tree in self.estimators_]

        # Average the predictions for each unique ID
        averaged_predictions = {}
        for uid in np.unique(ids):
            uid_predictions = [tree_preds[uid] for tree_preds in all_predictions if uid in tree_preds]
            averaged_predictions[uid] = self._pad_and_average_predictions(uid_predictions, len(self.estimators_))
    
        return averaged_predictions

In [56]:
x = data[['group','gender']].values

In [57]:
ids = data['id'].values
time_start = data['start'].values
time_stop = data['stop'].values
event = data['event'].values

In [58]:
rrf = RecurrentRandomForest(n_estimators=100, max_depth=2, random_state=42, oob_score=True, n_jobs=6)
y = {
    'id': ids,
    'time_start': time_start,
    'time_stop': time_stop,
    'event': event
}
rrf.fit(x,y)

  term = (self.riskset_left.n_events / self.riskset_left.n_at_risk) - (self.riskset_right.n_events / self.riskset_right.n_at_risk)
  res_ij = np.where(yVec > 0, y_i_tj / yVec * (n_i_tj - dLambda), 0)


In [59]:
rrf.estimators_

[RecurrentTree(max_depth=2, min_impurity_decrease=0.0, random_state=1608637542),
 RecurrentTree(max_depth=2, min_impurity_decrease=0.0, random_state=1273642419),
 RecurrentTree(max_depth=2, min_impurity_decrease=0.0, random_state=1935803228),
 RecurrentTree(max_depth=2, min_impurity_decrease=0.0, random_state=787846414),
 RecurrentTree(max_depth=2, min_impurity_decrease=0.0, random_state=996406378),
 RecurrentTree(max_depth=2, min_impurity_decrease=0.0, random_state=1201263687),
 RecurrentTree(max_depth=2, min_impurity_decrease=0.0, random_state=423734972),
 RecurrentTree(max_depth=2, min_impurity_decrease=0.0, random_state=415968276),
 RecurrentTree(max_depth=2, min_impurity_decrease=0.0, random_state=670094950),
 RecurrentTree(max_depth=2, min_impurity_decrease=0.0, random_state=1914837113),
 RecurrentTree(max_depth=2, min_impurity_decrease=0.0, random_state=669991378),
 RecurrentTree(max_depth=2, min_impurity_decrease=0.0, random_state=429389014),
 RecurrentTree(max_depth=2, min_imp

In [60]:
rrf_prediction=rrf.predict_mean_function(X=x,ids=ids)

In [61]:
rrf_prediction[1]

[0.13720711071024838,
 0.2632226631012831,
 0.3964122080729097,
 0.5313044947057949,
 0.6523634481472912,
 0.7854567837095128,
 0.9182448134065132,
 1.056373252585723,
 1.1947198346890022,
 1.3403559465866477,
 1.4869609228729574,
 1.6346254228943684,
 1.7865348722179355,
 1.9392542059838822,
 2.1008538813028492,
 2.2691559373880734,
 2.4418312589252924,
 2.6127526119824664,
 2.7846458264675795,
 2.9551028937465342,
 3.1389163816735306,
 3.312971933048343,
 3.468864785860618,
 3.624900495324529,
 3.7760473160303536,
 3.9243846118984766,
 4.006586990129565,
 4.078610797322281,
 4.121872700756278,
 4.158706032782113,
 4.172872698914056,
 4.185372698518221,
 4.188706031740445,
 4.188706031740445]

In [62]:
rrf_prediction[2]

[0.20507254723387627,
 0.3790736657816134,
 0.5379392538580605,
 0.6671540481889988,
 0.7202779365440972,
 0.7530043897233425,
 0.7751282790945515,
 0.7855842813315513,
 0.7958266083592501,
 0.8076084652815235,
 0.8204464882489882,
 0.8329082485881598,
 0.8454129382200689,
 0.8590251314449966,
 0.8699779840348714,
 0.8811094080485341,
 0.8929513992591986,
 0.9041703745961985,
 0.9150754971141462,
 0.9293063771674345,
 0.9415246309769014,
 0.9540920911058594,
 0.9648659004554498,
 0.980794471644215,
 0.9994174871631577,
 1.0217309786963436,
 1.0370047879084339,
 1.0563619303183716,
 1.0714571680754235,
 1.084290501056257,
 1.0934571674382014,
 1.0992905005979237,
 1.102623833820146,
 1.102623833820146]

In [63]:
rrf_prediction[26]

[0.20507254723387627,
 0.3790736657816134,
 0.5379392538580605,
 0.6671540481889988,
 0.7202779365440972,
 0.7530043897233425,
 0.7751282790945515,
 0.7855842813315513,
 0.7958266083592501,
 0.8076084652815235,
 0.8204464882489882,
 0.8329082485881598,
 0.8454129382200689,
 0.8590251314449966,
 0.8699779840348714,
 0.8811094080485341,
 0.8929513992591986,
 0.9041703745961985,
 0.9150754971141462,
 0.9293063771674345,
 0.9415246309769014,
 0.9540920911058594,
 0.9648659004554498,
 0.980794471644215,
 0.9994174871631577,
 1.0217309786963436,
 1.0370047879084339,
 1.0563619303183716,
 1.0714571680754235,
 1.084290501056257,
 1.0934571674382014,
 1.0992905005979237,
 1.102623833820146,
 1.102623833820146]

In [64]:
rrf_prediction[27]

[0.13720711071024838,
 0.2632226631012831,
 0.3964122080729097,
 0.5313044947057949,
 0.6523634481472912,
 0.7854567837095128,
 0.9182448134065132,
 1.056373252585723,
 1.1947198346890022,
 1.3403559465866477,
 1.4869609228729574,
 1.6346254228943684,
 1.7865348722179355,
 1.9392542059838822,
 2.1008538813028492,
 2.2691559373880734,
 2.4418312589252924,
 2.6127526119824664,
 2.7846458264675795,
 2.9551028937465342,
 3.1389163816735306,
 3.312971933048343,
 3.468864785860618,
 3.624900495324529,
 3.7760473160303536,
 3.9243846118984766,
 4.006586990129565,
 4.078610797322281,
 4.121872700756278,
 4.158706032782113,
 4.172872698914056,
 4.185372698518221,
 4.188706031740445,
 4.188706031740445]

In [65]:
rrf_prediction[100], len(rrf_prediction[100])

([0.20507254723387627,
  0.3790736657816134,
  0.5379392538580605,
  0.6671540481889988,
  0.7202779365440972,
  0.7530043897233425,
  0.7751282790945515,
  0.7855842813315513,
  0.7958266083592501,
  0.8076084652815235,
  0.8204464882489882,
  0.8329082485881598,
  0.8454129382200689,
  0.8590251314449966,
  0.8699779840348714,
  0.8811094080485341,
  0.8929513992591986,
  0.9041703745961985,
  0.9150754971141462,
  0.9293063771674345,
  0.9415246309769014,
  0.9540920911058594,
  0.9648659004554498,
  0.980794471644215,
  0.9994174871631577,
  1.0217309786963436,
  1.0370047879084339,
  1.0563619303183716,
  1.0714571680754235,
  1.084290501056257,
  1.0934571674382014,
  1.0992905005979237,
  1.102623833820146,
  1.102623833820146],
 34)

In [66]:
rrf_prediction

{1: [0.13720711071024838,
  0.2632226631012831,
  0.3964122080729097,
  0.5313044947057949,
  0.6523634481472912,
  0.7854567837095128,
  0.9182448134065132,
  1.056373252585723,
  1.1947198346890022,
  1.3403559465866477,
  1.4869609228729574,
  1.6346254228943684,
  1.7865348722179355,
  1.9392542059838822,
  2.1008538813028492,
  2.2691559373880734,
  2.4418312589252924,
  2.6127526119824664,
  2.7846458264675795,
  2.9551028937465342,
  3.1389163816735306,
  3.312971933048343,
  3.468864785860618,
  3.624900495324529,
  3.7760473160303536,
  3.9243846118984766,
  4.006586990129565,
  4.078610797322281,
  4.121872700756278,
  4.158706032782113,
  4.172872698914056,
  4.185372698518221,
  4.188706031740445,
  4.188706031740445],
 2: [0.20507254723387627,
  0.3790736657816134,
  0.5379392538580605,
  0.6671540481889988,
  0.7202779365440972,
  0.7530043897233425,
  0.7751282790945515,
  0.7855842813315513,
  0.7958266083592501,
  0.8076084652815235,
  0.8204464882489882,
  0.832908248

In [67]:
rrf.oob_score_

0.8

In [68]:
predictions = rrf.predict_mean_function(x,ids)

recurrent_concordance_index_score(predictions=predictions, X=x, event=event, ids=ids)

(0.638095238095238, 0.36190476190476195)

In [75]:
x = data[['group','x1','gender']].values
ids = data['id'].values
time_start = data['start'].values
time_stop = data['stop'].values
event = data['event'].values
rrf = RecurrentRandomForest(n_estimators=100, max_depth=3, random_state=1190, oob_score=True, n_jobs=6, verbose=1)
y = {
    'id': ids,
    'time_start': time_start,
    'time_stop': time_stop,
    'event': event
}
rrf.fit(x,y)

  term = (self.riskset_left.n_events / self.riskset_left.n_at_risk) - (self.riskset_right.n_events / self.riskset_right.n_at_risk)
  res_ij = np.where(yVec > 0, y_i_tj / yVec * (n_i_tj - dLambda), 0)


In [76]:
rrf.estimators_

[RecurrentTree(max_depth=3, min_impurity_decrease=0.0, random_state=1810984180),
 RecurrentTree(max_depth=3, min_impurity_decrease=0.0, random_state=1517185480),
 RecurrentTree(max_depth=3, min_impurity_decrease=0.0, random_state=907657181),
 RecurrentTree(max_depth=3, min_impurity_decrease=0.0, random_state=1382377102),
 RecurrentTree(max_depth=3, min_impurity_decrease=0.0, random_state=1268296320),
 RecurrentTree(max_depth=3, min_impurity_decrease=0.0, random_state=423611028),
 RecurrentTree(max_depth=3, min_impurity_decrease=0.0, random_state=495308102),
 RecurrentTree(max_depth=3, min_impurity_decrease=0.0, random_state=1327217891),
 RecurrentTree(max_depth=3, min_impurity_decrease=0.0, random_state=1883106677),
 RecurrentTree(max_depth=3, min_impurity_decrease=0.0, random_state=1748776709),
 RecurrentTree(max_depth=3, min_impurity_decrease=0.0, random_state=2103212226),
 RecurrentTree(max_depth=3, min_impurity_decrease=0.0, random_state=926159912),
 RecurrentTree(max_depth=3, min_

In [77]:
rrf_prediction=rrf.predict_mean_function(X=x,ids=ids)

In [78]:
rrf.oob_score_

0.7428571428571429

In [79]:
recurrent_concordance_index_score(predictions=rrf_prediction, X=x, event=event, ids=ids)

(0.7238095238095238, 0.2761904761904762)

In [80]:
rrf_prediction

{1: [0.23002813208222775,
  0.4490126133168241,
  0.6731637611033654,
  0.9012077633986486,
  1.0937398604629227,
  1.241775928523316,
  1.4035024707469794,
  1.4791698564684612,
  1.5365435922660373,
  1.5958253360315373,
  1.6450237475258573,
  1.6979840636327848,
  1.739210253046217,
  1.7831467598015396,
  1.8257499326279734,
  1.8706904071464994,
  1.9122975482437283,
  1.9536427847191837,
  1.9914046877937361,
  2.030083257413496,
  2.0564403996445453,
  2.0871308751028166,
  2.126678492457414,
  2.142773730205577,
  2.159273729674188,
  2.175273729290299,
  2.1921070620800216,
  2.2031070618350217,
  2.2126070616075215,
  2.215940394829744,
  2.224273727801966,
  2.2276070610241883,
  2.232607060774188,
  2.232607060774188],
 2: [0.3634646223155816,
  0.5712625881000515,
  0.6740605623428535,
  0.7335252066550995,
  0.770767629677222,
  0.8085100526768448,
  0.8381969204754647,
  0.8712928789144939,
  0.9011943930169111,
  0.9337777252351749,
  0.9756388341524818,
  1.0069166106