In [1]:
# google colab specific - installing probcox
!pip3 install torch==1.7.0
!pip3 install pyro-ppl==1.5.1
!pip3 install probcox

Collecting torch==1.7.0
  Downloading torch-1.7.0-cp37-cp37m-manylinux1_x86_64.whl (776.7 MB)
[K     |████████████████████████████████| 776.7 MB 4.9 kB/s 
Collecting dataclasses
  Downloading dataclasses-0.6-py3-none-any.whl (14 kB)
Installing collected packages: dataclasses, torch
  Attempting uninstall: torch
    Found existing installation: torch 1.10.0+cu111
    Uninstalling torch-1.10.0+cu111:
      Successfully uninstalled torch-1.10.0+cu111
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchvision 0.11.1+cu111 requires torch==1.10.0, but you have torch 1.7.0 which is incompatible.
torchtext 0.11.0 requires torch==1.10.0, but you have torch 1.7.0 which is incompatible.
torchaudio 0.10.0+cu111 requires torch==1.10.0, but you have torch 1.7.0 which is incompatible.[0m
Successfully installed dataclasses-0.6 torch-1.7.0
Collecting pyro-ppl==1.5.1
  D

In [2]:
# Modules
# =======================================================================================================================
import os
import sys
import time
import shutil
import subprocess
import tqdm

import numpy as np
import pandas as pd

import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist

from pyro.infer import SVI, Trace_ELBO

import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")

import probcox as pcox

dtype = torch.FloatTensor

# Set Seed
np.random.seed(293)
torch.manual_seed(34)

<torch._C.Generator at 0x7fc3ffa33090>

In [3]:
# Simulation Settings
# =======================================================================================================================

I = 10000 # Number of Individuals
P_binary = 0
P_continuous = 800
P = P_binary + P_continuous
theta = np.random.normal(0, 0.75, 25)[:, None]
theta = np.concatenate((np.zeros((P_continuous-25, 1)), theta))
scale = 15  # Scaling factor for Baseline Hazard

# Simulation
# =======================================================================================================================
# Class for simulation
TVC = pcox.TVC(theta=theta, P_binary=P_binary, P_continuous=P_continuous, dtype=dtype)

# Sample baseline hazard - scale is set to define censorship/events
TVC.make_lambda0(scale=scale)

# Sample Data
run_id=1 # change for different dataset 1-5
np.random.seed(run_id)
torch.manual_seed(run_id)
surv = torch.zeros((0, 3))
X = torch.zeros((0, P))
for __ in tqdm.tqdm(range(I)):
    a, b = TVC.sample()
    surv = torch.cat((surv, a))
    X = torch.cat((X, b))

total_obs = surv.shape[0]
total_events = torch.sum(surv[:, -1] == 1).numpy().tolist()
print('Obs: ', total_obs)
print('Censorship: ', 1-total_events/I)

100%|██████████| 10000/10000 [04:15<00:00, 39.08it/s]

Obs:  70755
Censorship:  0.6809000000000001





In [4]:
# Inference Setup
# =======================================================================================================================
# Custom linear predictor - Here: simple linear combination
def predictor(data):
    theta =  pyro.sample("theta", dist.StudentT(1, loc=0, scale=0.001).expand([data[1].shape[1], 1])).type(dtype)
    pred = torch.mm(data[1], theta)
    return(pred)

def evaluate(surv, X, rank, batchsize, sampling_proportion, iter_, run_suffix, predictor=predictor):
    sampling_proportion[1] = batchsize
    eta=1 # paramter for optimization
    run = True # repeat initalization if NAN encounterd while training - gauge correct optimization settings
    while run:
        run = False
        pyro.clear_param_store()
        m = pcox.PCox(sampling_proportion=sampling_proportion, predictor=predictor)
        m.initialize(eta=eta, rank=rank, num_particles=3)
        loss=[0]
        for ii in tqdm.tqdm(range((iter_))):
            idx = np.unique(np.concatenate((np.random.choice(np.where(surv[:, -1]==1)[0], 1, replace=False), np.random.choice(range(surv.shape[0]), batchsize, replace=False))))[:batchsize]
            data=[surv[idx], X[idx]] # subsampled data
            loss.append(m.infer(data=data))
            # divergence check
            if loss[-1] != loss[-1]:
                eta = eta * 0.1
                run=True
                break
    g = m.return_guide()
    out = g.quantiles([0.025, 0.5, 0.975])
    return(out)

In [5]:
start_time = time.time()
pyro.clear_param_store()
out = evaluate(run_suffix='', rank=10, batchsize=512, iter_=25000, surv=surv, X=X, sampling_proportion=[total_obs, None, total_events, None])
print("--- %s seconds ---" % (time.time() - start_time))
theta_iden = np.sign(out['theta'][0].detach().numpy()) == np.sign(out['theta'][2].detach().numpy())

  0%|          | 1/25000 [00:00<3:38:26,  1.91it/s]
100%|██████████| 25000/25000 [08:37<00:00, 48.31it/s]

--- 518.0189940929413 seconds ---





# R evaluation

In [6]:
# prepare data frame to sent to R
rd = pd.DataFrame(np.concatenate((surv.numpy(), X.numpy()), axis=1))
rd.columns = ['V' + str(ii) for ii in range(1, rd.shape[1]+1)]

rd_theta = pd.DataFrame(theta)
rd_theta.columns = ['V0']

In [7]:
# load R 
%load_ext rpy2.ipython

In [8]:
# install packages
%%R 
install.packages('survival')
install.packages('glmnet')
install.packages('devtools')
library(devtools)
install_github("michaelyanwang/dcalasso")

R[write to console]: Installing package into ‘/usr/local/lib/R/site-library’
(as ‘lib’ is unspecified)

R[write to console]: trying URL 'https://cran.rstudio.com/src/contrib/survival_3.3-1.tar.gz'

R[write to console]: Content type 'application/x-gzip'
R[write to console]:  length 6577371 bytes (6.3 MB)

R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[write to console]: =
R[wr

* checking for file ‘/tmp/RtmpgTh2hL/remotes4039f6392b/michaelyanwang-dcalasso-3574e07/DESCRIPTION’ ... OK
* preparing ‘dcalasso’:
* checking DESCRIPTION meta-information ... OK
* cleaning src
* checking for LF line-endings in source and make files and shell scripts
* checking for empty or unneeded directories
Omitted ‘LazyData’ from DESCRIPTION
* building ‘dcalasso_0.1.0.tar.gz’



R[write to console]: Installing package into ‘/usr/local/lib/R/site-library’
(as ‘lib’ is unspecified)



In [9]:
%%R -i rd -i theta -o beta_hat

set.seed(13)
library(glmnet)
library(survival)
library(dcalasso)

# dcalasso
start_time <- Sys.time()
mod = dcalasso(as.formula(paste0('Surv(V1,V2,V3)~',paste(paste0('V',4:803),collapse='+'))), family = 'cox.ph', data=rd, K = 3, iter.os = 4, ncores = 2)   
end_time <- Sys.time()
print(end_time - start_time)
beta_hat <- unname(mod$coefficients.pen)

R[write to console]: Loading required package: Matrix

R[write to console]: Loaded glmnet 4.1-3

R[write to console]: Loading required package: mgcv

R[write to console]: Loading required package: nlme

R[write to console]: This is mgcv 1.8-39. For overview type 'help("mgcv-package")'.

R[write to console]: Loading required package: doParallel

R[write to console]: Loading required package: foreach

R[write to console]: Loading required package: iterators

R[write to console]: Loading required package: parallel

R[write to console]: Loading required package: MASS



Time difference of 12.02879 mins


In [10]:
print('R: ', (np.abs(theta[-25:]-beta_hat[-25:])).mean())
print('ProbCox: ', (np.abs(theta[-25:]-out['theta'][1].detach().numpy()[-25:, 0])).mean())
print('R - TP:', (beta_hat[-25:] != 0).sum(), ', FP: ', (beta_hat[:-25] != 0).sum())
print('ProbCox - TP:', theta_iden[-25:].sum(), ', FP: ', theta_iden[:-25].sum())

R:  0.7593090408975961
ProbCox:  0.7809955076965891
R - TP: 25 , FP:  4
ProbCox - TP: 25 , FP:  0
