In [2]:
import numpy as np
import pandas as pd
from sksurv.datasets import load_whas500
from sksurv.linear_model import CoxPHSurvivalAnalysis
import seaborn as sns
import matplotlib.pyplot as plt

In [11]:
X, y = load_whas500()
X = X.astype(float)
# Combining features and events is easier to work with for now
combined = pd.concat([X, pd.DataFrame(y)], axis=1)
combined['lenfol'] = combined['lenfol'].astype(int)
TARGET_COLUMNS = ['fstat', 'lenfol']


combined

Unnamed: 0,afb,age,av3,bmi,chf,cvd,diasbp,gender,hr,los,miord,mitype,sho,sysbp,fstat,lenfol
0,1.0,83.0,0.0,25.54051,0.0,1.0,78.0,0.0,89.0,5.0,1.0,0.0,0.0,152.0,False,2178
1,0.0,49.0,0.0,24.02398,0.0,1.0,60.0,0.0,84.0,5.0,0.0,1.0,0.0,120.0,False,2172
2,0.0,70.0,0.0,22.14290,0.0,0.0,88.0,1.0,83.0,5.0,0.0,1.0,0.0,147.0,False,2190
3,0.0,70.0,0.0,26.63187,1.0,1.0,76.0,0.0,65.0,10.0,0.0,1.0,0.0,123.0,True,297
4,0.0,70.0,0.0,24.41255,0.0,1.0,85.0,0.0,63.0,6.0,0.0,1.0,0.0,135.0,False,2131
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
495,1.0,76.0,0.0,27.96454,0.0,1.0,88.0,1.0,68.0,1.0,0.0,1.0,0.0,112.0,True,10
496,0.0,76.0,0.0,24.26862,0.0,1.0,96.0,1.0,88.0,3.0,0.0,0.0,0.0,208.0,False,662
497,1.0,57.0,0.0,42.13576,0.0,1.0,74.0,1.0,123.0,3.0,0.0,0.0,0.0,120.0,False,725
498,0.0,67.0,0.0,27.40905,0.0,1.0,62.0,0.0,59.0,1.0,0.0,1.0,0.0,112.0,False,532


# Constructing the components
In order to solve equation 8we need to filter and group the data


## $D_t$
We need to group the records on event time, ignore the right-censored records

Then we get $D_t$ for every $t$ from $t=1$ to $T$

In [12]:
# First ignore all right-censored records
dt = combined[~combined['fstat']]

# We don't need the censor column anymore
dt = dt.drop(['fstat'], axis=1)

# Group on event time
dt = dt.groupby('lenfol')

dt.describe().head()

Unnamed: 0_level_0,afb,afb,afb,afb,afb,afb,afb,afb,age,age,...,sho,sho,sysbp,sysbp,sysbp,sysbp,sysbp,sysbp,sysbp,sysbp
Unnamed: 0_level_1,count,mean,std,min,25%,50%,75%,max,count,mean,...,75%,max,count,mean,std,min,25%,50%,75%,max
lenfol,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
368,1.0,0.0,,0.0,0.0,0.0,0.0,0.0,1.0,46.0,...,0.0,0.0,1.0,149.0,,149.0,149.0,149.0,149.0,149.0
371,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.0,73.333333,...,0.0,0.0,3.0,132.333333,18.610033,115.0,122.5,130.0,141.0,152.0
373,1.0,0.0,,0.0,0.0,0.0,0.0,0.0,1.0,65.0,...,0.0,0.0,1.0,164.0,,164.0,164.0,164.0,164.0,164.0
376,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,60.0,...,0.0,0.0,2.0,195.0,22.627417,179.0,187.0,195.0,203.0,211.0
386,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,55.5,...,0.0,0.0,2.0,141.5,34.648232,117.0,129.25,141.5,153.75,166.0


## $R_t$
$R_t$ denotes the set of samples at risk of the event at time $t$. This includes samples with an event at time t, the samples with an event later than time t, and right-censored samples.

*I __think__ that I can treat right-censored samples the same as regular samples for this set.*

In [13]:
rt = combined

# I will create a bucket per unique lenfol and create a new dataframe per bucket with all samples at risk at that time
unique_times = rt['lenfol'].unique()

unique_times


array([2178, 2172, 2190,  297, 2131,    1, 2122, 1496,  920, 2175, 2173,
       1671, 2192,  865, 2166, 2168,  905, 2353, 2146,   61, 2358, 2114,
       2132, 2139, 2048, 2152,    6, 2156,  118, 2064,  849,  714, 2057,
          2,    7, 2151,  422,  354, 2065, 1065,  535, 2118,   97, 2113,
        100, 2032, 1317, 2126, 2123,  670,  343,    3, 2009,   64, 1994,
       1579, 1993, 1955,   42, 1964, 1548,  446, 1976, 1942,  151, 2006,
       2086, 1969, 1939, 1940, 1576, 1941,  197, 1933,   95, 2160, 2084,
       2145, 2125, 1920,    4, 1553,  235,  192, 1233,   88, 1954,  903,
        612, 2025, 1887,  187,  101, 1885,  936,  363, 1048, 1977, 1936,
       1889, 1923,   11, 2100, 1914, 1883,   33, 1931, 1506, 1858, 1854,
       1847,   46, 2061, 1893, 2108,   83, 1377, 1863, 1880, 1359, 1831,
       1836, 1159,  113, 1217, 1899, 1934, 1527, 1979, 1232, 2066, 1624,
        530, 1096,  345, 1919, 1577, 1904, 2083,  146, 2350, 1926,  718,
       1451,  358,  465, 1381, 1385, 1346, 1338,  1

In [33]:
def group_samples_at_risk(samples: pd.DataFrame):
    unique_times = samples['lenfol'].unique()
    
    grouped = {}
    
    for t in unique_times:
        grouped[t] = samples[samples['lenfol']>= t]
        
    return grouped

Rt = group_samples_at_risk(combined)



# Testing if the resulting list descends in numbers
previous_length = len(combined) + 1

for t in sorted(grouped.keys()):
    length = len(grouped[t])
    
    assert length < previous_length
    
    previous_length = length

## $\sum \limits_{t=1}^{T} \sum \limits_{n \in D_t} \mathbf{x}_{nk}$
This part seems to be constant throughout the optimization?

I think this is just a big fat sum of all the patients' covariants. It will stay constant per institution.

In [4]:
covariates_sum = combined.drop(TARGET_COLUMNS, axis=1).values.sum(axis=0)

covariates_sum


array([7.800000e+01, 3.492300e+04, 1.100000e+01, 1.330689e+04,
       1.550000e+02, 3.750000e+02, 3.913300e+04, 2.000000e+02,
       4.350900e+04, 3.058000e+03, 1.710000e+02, 1.530000e+02,
       2.200000e+01, 7.235200e+04])

In [5]:
# Covariates
X.values

array([[  1.,  83.,   0., ...,   0.,   0., 152.],
       [  0.,  49.,   0., ...,   1.,   0., 120.],
       [  0.,  70.,   0., ...,   1.,   0., 147.],
       ...,
       [  1.,  57.,   0., ...,   0.,   0., 120.],
       [  0.,  67.,   0., ...,   1.,   0., 112.],
       [  0.,  98.,   0., ...,   1.,   0., 160.]])

## Local update

$ \beta_k^{(p)} = \bigg[ \rho \sum \limits_{n=1}^{N} \mathbf{x}_{nk}\mathbf{x}_{nk}^T\bigg]^{-1} \cdot \bigg[\sum \limits_{n=1}^N  (\rho z_{nk}^{(p-1)} - \gamma_{nk}^{p-1}) \mathbf{x}_{nk} + \sum \limits_{t=1}^T \sum \limits_{n \in D_t} \mathbf{x}_{nk}\bigg] $

In [6]:
# Local update
RHO = 0.25

multiplied_covariates = (X* X.transpose()).sum(axis=0)

def local_update(covariates:np.array, events, z:np.array, gamma:np.array, previous_z, previous_gamma, rho=RHO):
    # Group patients on event times, then sum all covariates
    
    
    # square all covariates and sum them together
    first_component = 1/(RHO * (np.square(covariates).sum()))
    
    pz = np.multiply(rho, previous_z)
    
    second_component = np.multiply(pz - previous_gamma, covariates) + covariates_sum    
    
    return first_component * second_component


## Server update
- Server computes:
    - $\overline{\sigma}_n^{(p)} = \sum \limits_{k=1}^K \sigma_{nk}^{(p)}/K $
    - $\overline{\gamma}_{n}^{(p)} = \sum \limits_{k=1}^K \gamma_{nk}^{(p)}/K $
- Server computes $\overline{z}^{(p)}$ by applying Newton-Raphson to:
$ \sum_{t=1}^T \left[d_t log \sum \limits_{j \in R_t} exp(K \overline{z}_j) \right] + K \rho \sum \limits_{n=1}^N \left[ \frac{\overline{z}_n^2}{2} - 
\left( \overline{\sigma}_n^{(p)} + \frac{\overline{\gamma}_n^{(p-1)}}{\rho} \right) \overline{z}_n \right]    $

$ \left[ \right] $

In [None]:
K = 1 #Number of institutions

def L_z(z):
    

In [34]:
import scipy


In [35]:
def (a:np.array):
    return a * a.transpose()



SyntaxError: invalid syntax (1779160323.py, line 1)