In [2]:
import numpy as np
import pandas as pd
from sksurv.datasets import load_whas500
from sksurv.linear_model import CoxPHSurvivalAnalysis
import seaborn as sns
import matplotlib.pyplot as plt

In [3]:
X, y = load_whas500()
X = X.astype(float)
# Combining features and events is easier to work with for now
combined = pd.concat([X, pd.DataFrame(y)], axis=1)
combined['lenfol'] = combined['lenfol'].astype(int)
TARGET_COLUMNS = ['fstat', 'lenfol']



right_censored = np.array([el[0] for el in y])
event_times = np.array([el[1] for el in y]).astype(int)

# Constructing the components
In order to solve equation 8we need to filter and group the data


## $D_t$
We need to group the records on event time, ignore the right-censored records

Then we get $D_t$ for every $t$ from $t=1$ to $T$

In [4]:
# First ignore all right-censored records
dt = combined[~combined['fstat']]

# We don't need the censor column anymore
dt = dt.drop(['fstat'], axis=1)

# Group on event time
dt = dt.groupby('lenfol')

dt.describe().head()

Unnamed: 0_level_0,afb,afb,afb,afb,afb,afb,afb,afb,age,age,...,sho,sho,sysbp,sysbp,sysbp,sysbp,sysbp,sysbp,sysbp,sysbp
Unnamed: 0_level_1,count,mean,std,min,25%,50%,75%,max,count,mean,...,75%,max,count,mean,std,min,25%,50%,75%,max
lenfol,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
368,1.0,0.0,,0.0,0.0,0.0,0.0,0.0,1.0,46.0,...,0.0,0.0,1.0,149.0,,149.0,149.0,149.0,149.0,149.0
371,3.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.0,73.333333,...,0.0,0.0,3.0,132.333333,18.610033,115.0,122.5,130.0,141.0,152.0
373,1.0,0.0,,0.0,0.0,0.0,0.0,0.0,1.0,65.0,...,0.0,0.0,1.0,164.0,,164.0,164.0,164.0,164.0,164.0
376,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,60.0,...,0.0,0.0,2.0,195.0,22.627417,179.0,187.0,195.0,203.0,211.0
386,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,55.5,...,0.0,0.0,2.0,141.5,34.648232,117.0,129.25,141.5,153.75,166.0


## $R_t$
$R_t$ denotes the set of samples at risk of the event at time $t$. This includes samples with an event at time t, the samples with an event later than time t, and right-censored samples.

*I __think__ that I can treat right-censored samples the same as regular samples for this set.*

In [5]:
rt = combined

# I will create a bucket per unique lenfol and create a new dataframe per bucket with all samples at risk at that time
unique_times = rt['lenfol'].unique()

num_unique_times = len(unique_times)


In [6]:
np.unique(np.arange(4))

array([0, 1, 2, 3])

In [7]:
def group_samples_at_risk(event_times: np.array):
    """
    Groups the indices of samples on whether they are at risk at a certain time.
    
    A sample is at risk at a certain time when its event time is greater or equal that time.
    """
    unique_times = np.unique(event_times)
    
    grouped = {}
    
    for t in unique_times:
        grouped[t] = np.argwhere(event_times>= t)
    
    return grouped

Rt = group_samples_at_risk(event_times)




## $\sum \limits_{t=1}^{T} \sum \limits_{n \in D_t} \mathbf{x}_{nk}$

$D_t$ is the list of indices with an observed event at time $t$.

This part seems to be constant throughout the optimization?

I think this is just a big fat sum of all the patients' covariants. It will stay constant per institution.


In [8]:
covariates_sum = combined.drop(TARGET_COLUMNS, axis=1).values.sum(axis=0)

covariates_sum


array([7.800000e+01, 3.492300e+04, 1.100000e+01, 1.330689e+04,
       1.550000e+02, 3.750000e+02, 3.913300e+04, 2.000000e+02,
       4.350900e+04, 3.058000e+03, 1.710000e+02, 1.530000e+02,
       2.200000e+01, 7.235200e+04])

In [9]:
# Covariates
X.values

array([[  1.,  83.,   0., ...,   0.,   0., 152.],
       [  0.,  49.,   0., ...,   1.,   0., 120.],
       [  0.,  70.,   0., ...,   1.,   0., 147.],
       ...,
       [  1.,  57.,   0., ...,   0.,   0., 120.],
       [  0.,  67.,   0., ...,   1.,   0., 112.],
       [  0.,  98.,   0., ...,   1.,   0., 160.]])

## Local update

$ \beta_k^{(p)} = \bigg[ \rho \sum \limits_{n=1}^{N} \mathbf{x}_{nk}\mathbf{x}_{nk}^T\bigg]^{-1} \cdot \bigg[\sum \limits_{n=1}^N  (\rho z_{nk}^{(p-1)} - \gamma_{nk}^{p-1}) \mathbf{x}_{nk} + \sum \limits_{t=1}^T \sum \limits_{n \in D_t} \mathbf{x}_{nk}\bigg] $

According to the paper, $\mathbf{\beta}$ will not be returned but $\sigma$, which is computed as follows:
$\sigma_{nk} = \mathbf{\beta}_k^T \mathbf{x}_{nk}$ where $k=1,...,K$

### Performance optimization
There are two parts to this computation that seem to be constant over iterations:
1. $\rho \sum \limits_{n=1}^{N} \mathbf{x}_{nk}\mathbf{x}_{nk}^T$
2. $\sum \limits_{t=1}^T \sum \limits_{n \in D_t} \mathbf{x}_{nk}$

Number 2. is also the part where we need to apply the scalar product protocol. We need this because (TODO: verify) we need to filter out right-censored samples.

I think we can rewrite number 2. to:

$\sum \limits_{n \in E} \mathbf{x}_{nk}$

Where $E$ is the collection of samples that are NOT right-censored.



In [10]:
# Local update
RHO = 0.25

# Parts that stay constant over iterations
# Square all covariates and sum them together
# The formula says for every patient, x needs to be multiplied by itself.
# Squaring all covariates with themselves comes down to the same thing since x_nk is supposed to
# be one-dimensional
multiplied_covariates = (X* X.transpose()).sum(axis=0)
covariates_summed = combined.drop(TARGET_COLUMNS, axis=1).values.sum(axis=0)

def sum_covariates(covariates: np.array):
    return np.sum(covariates, axis=0)
    
def multiply_covariates(covariates: np.array):
    return np.square(covariates).sum()

def elementwise_multiply_sum(one_dim: np.array, two_dim: np.array):
    """
    Every element in one_dim does elementwise multiplication with its corresponding row in two_dim.
    
    All rows of the result will be summed together vertically.
    """
    multiplied = np.zeros(two_dim.shape)
    for i in range(one_dim.shape[0]):
        multiplied[i] = one_dim[i] * two_dim[i]
        
    return multiplied.sum(axis=0)
    
    

def compute_beta(covariates:np.array, z:np.array, gamma:np.array, rho,
                 multiplied_covariates, covariates_sum):
    
    first_component = 1/(rho * multiplied_covariates)
    
    pz = rho * z
    
    second_component = elementwise_multiply_sum(pz - gamma, covariates) + covariates_sum    
    
    return second_component/ first_component

def compute_sigma(beta, covariates):
    return np.matmul(covariates, beta)

def local_update(covariates:np.array, z:np.array, gamma:np.array, rho,
                 multiplied_covariates, covariates_sum):
    
    beta = compute_beta(covariates, z, gamma, rho, multiplied_covariates, covariates_sum)
    
    return compute_sigma(beta, covariates)
    
    

## Tests

In [11]:
def test_sum_covariates_returns_one_dim_array():
    num_patients = 2
    num_features = 2
    
    covariates = np.arange(num_patients * num_features).reshape((num_patients, num_features))
    
    result = sum_covariates(covariates)
    assert result.shape == (num_features, ), f'Result is not one dimensional but shape {result.shape}'

def test_multiply_covariates_returns_scalar():
    num_patients = 2
    num_features = 2
    
    covariates = np.arange(num_patients * num_features).reshape((num_patients, num_features))
    
    result = multiply_covariates(covariates)
    assert np.isscalar(result) , f'Result is not scalar but shape {result.shape}'

def test_elementwise_multiply_sum():
    two_dim = np.array([[1,2], [3,4], [5,6]])
    one_dim = np.array([1,2,3])
    
    result = elementwise_multiply_sum(one_dim, two_dim)
    
    assert result.shape == (two_dim.shape[1], ), f'Result shape is not same as number of columns in two_dim ({two_dim.shape[1]}) but {result.shape}'
    
    np.testing.assert_array_equal(result, np.array([22, 28]))
    
    
def test_local_update():
    num_patients = 3
    num_features = 2
    
    rho=1
    covariates = np.arange(num_patients*num_features).reshape((num_patients, num_features))
    z = np.arange(num_patients)
    gamma = np.arange(num_patients)
    multiplied_cov = multiply_covariates(covariates)
    summed_cov = sum_covariates(covariates)
    
    sigma = local_update(covariates, z, gamma, rho, multiplied_cov, summed_cov)
    
    assert sigma.shape == (num_patients, ), f'Updated value is not an array of shape {(num_features, )} but of shape: {updated}'

test_sum_covariates_returns_one_dim_array()
test_multiply_covariates_returns_scalar()
test_elementwise_multiply_sum()
test_local_update()

## Server update
- Server computes:
    - $\overline{\sigma}_n^{(p)} = \sum \limits_{k=1}^K \sigma_{nk}^{(p)}/K $
    - $\overline{\gamma}_{n}^{(p)} = \sum \limits_{k=1}^K \gamma_{nk}^{(p)}/K $
- Server computes $\overline{z}^{(p)}$ by applying Newton-Raphson to:
$ \sum_{t=1}^T \left[d_t log \sum \limits_{j \in R_t} exp(K \overline{z}_j) \right] + K \rho \sum \limits_{n=1}^N \left[ \frac{\overline{z}_n^2}{2} - 
\left( \overline{\sigma}_n^{(p)} + \frac{\overline{\gamma}_n^{(p-1)}}{\rho} \right) \overline{z}_n \right]    $

### Person-level auxiliary variables
For the update the server makes use of the auxiliary variables $\overline{\sigma}$ and $\overline{\gamma}$. The elements of these vectors have a one-on-one relationship with the patients.

Moreover, the server tries to find a variable $\overline{z}$ which not only has a one-on-one relationship with the patients, but also needs to be grouped based on patients' event times.

### Optimization method
In order to get to a working end result I will skip the step of implementing the gradient and the Hessian. I will use the default method from `scipy.optimize.minimize`.

In [15]:
K = 1 #Number of institutions
dt = num_unique_times # Number of unique event times

def L_z_parametrized(z: np.array, K: int, gamma:np.array, sigma, rho, samples_at_risk):
    
    dt = len(Rt)
    
    component1 = L_z_component1(z, samples_at_risk, dt)
    component2 = L_z_component2(z, K, sigma,gamma, rho)
    
    return component1 + component2
        
def L_z_component1(z, samples_at_risk, unique_event_times):
    result = 0
    for t, group in samples_at_risk.items():
        z_at_risk = z[group]
        result += dt * (K * np.exp(z_at_risk)).sum()
    
    return result
        
def L_z_component2(z, K, sigma, gamma, rho):
    element_wise = np.square(z)/2 - sigma + (gamma/rho) * z
    return K * rho * element_wise.sum()
    

# Test if the output type is as expected (should be a scalar)

def test_lz_outputs_scalar():
    num_patients, num_features = 3, 2
    num_parties = 1
    samples_at_risk = {1: [0], 2: [1]}
    
    z = np.arange(num_patients)
    gamma = z
    sigma = z
    rho = 2
    
    result = L_z_parametrized(z, num_parties, gamma, sigma, rho, samples_at_risk)
    
    assert np.isscalar(result)
    
test_lz_outputs_scalar()
    


    

## TODO: Implement derivatives
As mentioned before, I'm skipping the implementaiton of the first- and second order partial derivatives for now. I will try to work with the default option of scipy that works without derivatives.
Later on I will see if the computation can be sped up by using newton-raphson.

In [63]:
from scipy.optimize import minimize

def compute_z(num_parties, gamma, sigma, rho, samples_at_risk, z_start):
    L_z = lambda z: L_z_parametrized(z, num_parties, gamma, sigma, rho, samples_at_risk)
    
    minimum = minimize(L_z, z_start)
    
    return minimum.x
        

      fun: 40.43530755237292
 hess_inv: array([[ 0.0988872 , -0.00172494],
       [-0.00172494,  0.10940803]])
      jac: array([8.10623169e-06, 4.76837158e-06])
  message: 'Optimization terminated successfully.'
     nfev: 63
      nit: 18
     njev: 21
   status: 0
  success: True
        x: array([-3.91971792, -4.02566807])

## Risks
### Differential privacy-ish
If the difference between $D_t$ and $D_{t+1}$, and similarly, the difference between $R_t$ and $R_{t+1}$ is too small, there is a great risk of data leakage. This needs to be addressed.

### "Gradient" leakage
The central server computes a variable $\boldsymbol{\overline{z}}$ which is a vector where every element corresponds to an individual patient.

# Next steps
The next step is to put the puzzle pieces together into some kind of datanode and central node entities.

I think it is best if I start to move to regular python modules now