In [1]:
from survival_func import survival_fit


In [34]:
from scipy.optimize import minimize
import autograd.numpy as np
from autograd import grad, jacobian, hessian
import pandas as pd
from autograd import grad

In [35]:
# Toy data. Each row is [time, prob, covariate_vector] (the layout mod_prob_density /
# mod_overall_survival index into); covariate vectors have length 3.
censored_inputs = [
    [0.5, 0.5, [-0.5, 0.5, 0.5]],
    [0.01, 0.02, [-0.05, 0.0, -0.8]],
]

noncensored_inputs = [
    [1.01, 0.50, [-0.09, 0.97, -0.7]],
    [0.1, 0.2, [-0.99, 0.5, 0.4]],
]


In [7]:
# Fit the model with the external survival_fit helper (from survival_func).
# NOTE(review): the meaning of the positional arguments 0.5 and 100 is defined in
# survival_func.survival_fit, not in this notebook -- confirm and give them named constants.
fit = survival_fit(censored_inputs, noncensored_inputs, 0.5, 100)

In [5]:
#for i in range(1000):
 #   fit = survival_fit(censored_inputs, noncensored_inputs, 0.5, 100)

In [8]:
# Display the fitted parameters. NOTE(review): the 5 entries presumably follow the
# [scaling, shape, gamma_1, gamma_2, gamma_3] layout that training_loss uses -- confirm.
fit

array([ 0.59419715,  1.15357323, -0.16345502, -0.06708443,  0.15381426])

In [24]:
def f(param):
    '''Toy objective used to try out autograd: x**2 + y**2 for the first two entries of param.'''
    x, y = param[0], param[1]
    return x**2 + y**2

In [25]:
f([1.,1.])  # sanity check: 1**2 + 1**2 == 2

2.0

In [26]:
# Gradient of f via autograd.
# NOTE(review): 'df' conventionally names a pandas DataFrame (pandas is imported in this
# notebook); a name like 'grad_f' would be less confusing.
df =  grad(f)

In [27]:
df([1.,1.])  # gradient of x**2 + y**2 at (1, 1) is (2, 2)

[array(2.), array(2.)]

In [28]:
from autograd import hessian

In [29]:
ddf = hessian(f)  # Hessian of f via autograd

In [33]:
ddf(np.array([1.0,1.0]))  # Hessian of x**2 + y**2 is the constant matrix 2*I

array([[2., 0.],
       [0., 2.]])

In [41]:
n_cens = len(censored_inputs)
n_noncens = len(noncensored_inputs)
n_rows = n_cens + n_noncens


def training_loss(flatparam, censored=None, noncensored=None):
    '''Regularised, row-averaged negative log-likelihood of the survival model.

    Parameters:
    ------------------------------------
    flatparam: 1-D array-like laid out as [scaling, shape, gamma_1, ..., gamma_k].
        The first two entries are passed as 'scaling' and 'shape'; the rest is gamma.

    censored: rows [time, prob, covariate_vector] for censored individuals.
        Defaults to the module-level 'censored_inputs' (looked up at call time).

    noncensored: rows [time, prob, covariate_vector] for non-censored individuals.
        Defaults to the module-level 'noncensored_inputs' (looked up at call time).

    Returns:
    -----------------------------------
    0.5*||gamma||^2 - (sum(log prob_density over non-censored rows)
                       + sum(log overall_survival over censored rows)) / (number of rows)

    Only 'flatparam' is differentiated by grad/hessian (autograd's default argnum=0),
    so the extra keyword arguments do not affect training_gradient.
    '''
    if censored is None:
        censored = censored_inputs
    if noncensored is None:
        noncensored = noncensored_inputs

    gamma = flatparam[2:]
    param = [flatparam[0], flatparam[1], gamma]

    # L2 penalty on the covariate coefficients only (scale and shape are not penalised).
    reg = np.dot(np.array(gamma), np.array(gamma))

    # Count rows from the data actually used on this call; a module-level count
    # computed once would go stale if the input lists were redefined later.
    row_count = len(censored) + len(noncensored)
    if row_count == 0:
        # No data: only the penalty remains (avoids dividing by zero).
        return 0.5*reg

    # Training loss is the negative log-likelihood:
    # observed events contribute log density, censored rows contribute log survival.
    known_loss = np.log(np.array(mod_prob_density(noncensored, param)))
    unknown_loss = np.log(np.array(mod_overall_survival(censored, param)))

    return 0.5*reg - 1/row_count*(np.sum(known_loss) + np.sum(unknown_loss))


training_gradient = grad(training_loss)

In [39]:
def covariate_exp(covariate_vector, gamma):
    '''Return the scalar exp(gamma . covariate_vector).

    Both arguments are equal-length sequences (lists or arrays); they are
    converted with np.array before the dot product.
    '''
    linear_predictor = np.dot(np.array(gamma), np.array(covariate_vector))
    return np.exp(linear_predictor)
    
def susc_survival(time, covariate_vector, scaling, shape, gamma):
    '''Survival function of a susceptible individual (Eq 16 from paper).

    Proportional-hazards model with a Weibull baseline:
        S(t) = exp( -(t/scaling)**shape * exp(gamma . covariate_vector) )

    Parameters:
    ------------------------------------
    time: Positive scalar; either a float or an integer.

    covariate_vector: A list.

    scaling: positive float; scale of the Weibull baseline hazard.

    shape: positive float; shape of the Weibull baseline hazard.

    gamma: A list of length matching 'covariate_vector'.

    Returns:
    -----------------------------------
    Survival function of the susceptible individual at 'time' for the given parameters.
    '''

    arg = (-(time/scaling)**shape)*(covariate_exp(covariate_vector, gamma))

    # Evaluate the exponential once and reuse it for both the check and the result.
    out = np.exp(arg)

    # exp() is never negative, so in practice this only trips on a NaN result.
    assert out>=0,"Output of 'susc_survival' is not a nonnegative number."

    return out
    
def overall_survival(time, prob, covariate_vector, scaling, shape, gamma):
    '''Overall survival function of a (not necessarily susceptible) individual (Eq 17 from paper).

        S(t) = prob + (1 - prob) * susc_survival(t, ...)

    Takes the same parameters as 'susc_survival', plus
    prob: Estimated probability (to be returned by HardEM); float between 0 and 1.
    '''
    susceptible_part = susc_survival(time, covariate_vector, scaling, shape, gamma)
    mixture = prob + (1 - prob)*susceptible_part

    assert 0 <= mixture <= 1, "Output of 'overall_survival' is not in [0,1]."

    return mixture
    
        
def prob_density(time, prob, covariate_vector, scaling, shape, gamma):
    '''Overall probability density of the event time (Eq 18 from paper).

    Computed as f(t) = -dS/dt, where S is 'overall_survival', by differentiating
    with autograd with respect to time while the other arguments are held fixed.

    Takes the same parameters as 'overall_survival'. The arguments are validated
    before the derivative is taken, so invalid inputs are reported by these
    assertions instead of failing somewhere inside the survival computation.
    '''
    assert time>0, 'time<=0'
    assert (0<=prob<=1), 'prob out of [0,1]'
    assert scaling>0, 'scaling not positive'
    assert shape>0, 'shape not positive'

    def time_slice(time_param):
        # overall_survival as a function of time alone, so it can be differentiated in t.
        return overall_survival(time_param, prob, covariate_vector, scaling, shape, gamma)

    # float(): autograd cannot differentiate with respect to an int argument.
    return -grad(time_slice)(float(time))
    
def mod_prob_density(Array, param):
    '''
    Row-wise version of 'prob_density', used for the non-censored term of the training loss.

    Parameters:

    --------------------------

    Array: rows of the form [[time, prob, covariate_vector], ...]; covariate_vector is an
    array and the other two entries are floats.

    param: [scaling, shape, gamma]; gamma is an array and the other two entries are floats.

    Returns:

    ---------------------------

    A list holding prob_density evaluated on each row of Array under param.
    '''
    out = []
    for row in Array:
        density = prob_density(row[0], row[1], row[2], param[0], param[1], param[2])
        out.append(density)

    assert isinstance(out, list),"Output of 'mod_prob_density' is not a list."

    return out
    
def mod_overall_survival(Array, param):

    '''
    Row-wise version of 'overall_survival', used for the censored term of the training loss.

    Parameters:

    --------------------------

    Array: rows of the form [[time, prob, covariate_vector], ...].

    param: [scaling, shape, gamma].

    Returns:

    ---------------------------

    A list holding overall_survival evaluated on each row of Array under param.
    '''

    out = [overall_survival(arr[0], arr[1], arr[2], param[0], param[1], param[2]) for arr in Array]

    # Message previously named 'mod_prob_density' (copy-paste slip).
    assert isinstance(out, list),"Output of 'mod_overall_survival' is not a list."

    return out

In [42]:
# Evaluate the loss at flatparam = [scaling, shape, gamma_1, gamma_2, gamma_3] = all 0.1.
training_loss([0.1, 0.1, 0.1, 0.1, 0.1])

1.6172114357403882

In [43]:
ddtrain = hessian(training_loss)  # Hessian of the loss w.r.t. flatparam (autograd)

In [53]:
# NOTE: evaluated at shape = 0.7, not at the all-0.1 point used for training_loss above.
matrix = ddtrain(np.array([0.1, 0.7, 0.1, 0.1, 0.1]))

In [54]:
def is_pos_def(x, tol=0.0):
    '''Return whether the real square matrix 'x' is positive definite.

    Tests the eigenvalues of the symmetric part (x + x.T)/2, which determine the
    sign of the quadratic form v.T @ x @ v for every real v. For an exactly
    symmetric matrix this is the usual eigenvalue test. 'eigvalsh' always returns
    real eigenvalues, unlike 'eigvals', which can return complex values for a
    non-symmetric (or slightly asymmetric, e.g. numerically computed) matrix;
    comparing complex numbers against 0 does not give a meaningful answer.

    tol: eigenvalues must be strictly greater than this (default 0.0).
    '''
    x = np.asarray(x)
    symmetric_part = 0.5*(x + x.T)
    return np.all(np.linalg.eigvalsh(symmetric_part) > tol)

In [55]:
is_pos_def(matrix)  # True => the loss Hessian is positive definite at the evaluation point above

True

In [50]:
def check_symmetric(a, rtol=1e-05, atol=1e-08):
    '''Return whether array 'a' equals its own transpose within np.allclose tolerances.'''
    transposed = a.T
    return np.allclose(a, transposed, atol=atol, rtol=rtol)

In [56]:
# The Hessian should be symmetric; the tolerances passed here equal the defaults.
check_symmetric(matrix, rtol=1e-05, atol=1e-08)

True

In [57]:
matrix  # display the 5x5 Hessian of training_loss computed above

array([[ 1.51633796e+02, -3.04256182e+01,  2.31570882e+00,
        -9.36370565e+00,  6.07653389e+00],
       [-3.04256182e+01,  7.95911950e+00, -2.10241696e-01,
         2.82965521e+00, -2.04953222e+00],
       [ 2.31570882e+00, -2.10241696e-01,  1.23719109e+00,
        -2.18635036e-01,  7.55953992e-04],
       [-9.36370565e+00,  2.82965521e+00, -2.18635036e-01,
         2.25450461e+00, -8.38875706e-01],
       [ 6.07653389e+00, -2.04953222e+00,  7.55953992e-04,
        -8.38875706e-01,  1.68145660e+00]])