# Lights model with PBC2 dataset

In [1]:
# Library setup
%reset -f
%matplotlib inline
import os
os.environ['R_HOME'] = "/Library/Frameworks/R.framework/Versions/4.0/Resources"
import warnings
warnings.filterwarnings('ignore')
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from lifelines.utils import concordance_index as c_index_score
from lights.inference import prox_QNEM
from lights.base.utils import visualize_vect_learning, plot_history
import numpy as np
from sklearn.preprocessing import StandardScaler
from lifelines import KaplanMeierFitter
import matplotlib.pyplot as plt
import seaborn as sns
from lifelines.utils import concordance_index as c_index_score
from prettytable import PrettyTable
from time import time
from competing_methods.all_model import load_data, extract_lights_feat, extract_R_feat
from matplotlib import rc
rc('text', usetex=True)
%matplotlib inline

def printmd(string):
    display(Markdown(string))
    
import rpy2.robjects as robjects
import warnings
%load_ext rpy2.ipython

## PBC_Seq

In [2]:
t = PrettyTable(['Algos', 'C_index', 'time'])
test_size = .2
simu = False
data, data_lights, Y_tsfresh, time_dep_feat, time_indep_feat = load_data(simu)
id_list = data_lights["id"]
nb_test_sample = int(test_size * len(id_list))
id_test = np.random.choice(id_list, size=nb_test_sample, replace=False)
data_lights_train = data_lights[~data_lights.id.isin(id_test)]
data_lights_test = data_lights[data_lights.id.isin(id_test)]
Y_tsfresh_train = Y_tsfresh[~Y_tsfresh.id.isin(id_test)]
Y_tsfresh_test = Y_tsfresh[Y_tsfresh.id.isin(id_test)]
X_lights_train, Y_lights_train, T_train, delta_train = \
    extract_lights_feat(data_lights_train, time_indep_feat, time_dep_feat)
X_lights_test, Y_lights_test, T_test, delta_test = \
    extract_lights_feat(data_lights_test, time_indep_feat, time_dep_feat)

data_train = data[~data.id.isin(id_test)]
data_test = data[data.id.isin(id_test)]
data_R_train, T_R_train, delta_R_train = extract_R_feat(data_train)
data_R_test, T_R_test, delta_R_test = extract_R_feat(data_test)

R[write to console]: Le chargement a nécessité le package : nlme

R[write to console]: Le chargement a nécessité le package : survival

R[write to console]: Le chargement a nécessité le package : doParallel

R[write to console]: Le chargement a nécessité le package : foreach

R[write to console]: Le chargement a nécessité le package : iterators

R[write to console]: Le chargement a nécessité le package : parallel

R[write to console]: Le chargement a nécessité le package : rstan

R[write to console]: Le chargement a nécessité le package : StanHeaders

R[write to console]: Le chargement a nécessité le package : ggplot2

R[write to console]: rstan (Version 2.21.2, GitRev: 2e1f913d3ca3)

R[write to console]: For execution on a local, multicore CPU with excess RAM we recommend calling
options(mc.cores = parallel::detectCores()).
To avoid recompilation of unchanged Stan programs, we recommend calling
rstan_options(auto_write = TRUE)



In [3]:
# The penalized time-dependent Cox model.
robjects.r.source(os.getcwd() + "/competing_methods/CoxNet.R")
X_R_train = robjects.r["Cox_get_long_feat"](data_R_train, time_dep_feat)
X_R_test = robjects.r["Cox_get_long_feat"](data_R_test, time_dep_feat)
best_lambda = robjects.r["Cox_cross_val"](X_R_train, T_R_train, delta_R_train)
start = time()
trained_CoxPH = robjects.r["Cox_fit"](X_R_train, T_R_train,
                                      delta_R_train, best_lambda)
Cox_pred = robjects.r["Cox_score"](trained_CoxPH, X_R_test)
Cox_marker = np.array(Cox_pred[:])
Cox_c_index = c_index_score(T_test, Cox_marker, delta_test)
Cox_c_index = max(Cox_c_index, 1 - Cox_c_index)
Cox_exe_time = time() - start

R[write to console]: Le chargement a nécessité le package : mvtnorm

R[write to console]: 
Attachement du package : ‘lcmm’


R[write to console]: The following objects are masked from ‘package:nlme’:

    fixef, ranef


R[write to console]: Le chargement a nécessité le package : Matrix

R[write to console]: Loaded glmnet 4.1-3



Be patient, hlme is running ... 
The program took 0.09 seconds 
Be patient, hlme is running ... 
The program took 0.15 seconds 
Be patient, hlme is running ... 
The program took 0.1 seconds 
Be patient, hlme is running ... 
The program took 0.11 seconds 
Be patient, hlme is running ... 
The program took 0.17 seconds 
Be patient, hlme is running ... 
The program took 0.16 seconds 
Be patient, hlme is running ... 
The program took 0.11 seconds 
Be patient, hlme is running ... 
The program took 0.02 seconds 
Be patient, hlme is running ... 
The program took 0.03 seconds 
Be patient, hlme is running ... 
The program took 0.03 seconds 
Be patient, hlme is running ... 
The program took 0.04 seconds 
Be patient, hlme is running ... 
The program took 0.03 seconds 
Be patient, hlme is running ... 
The program took 0.03 seconds 
Be patient, hlme is running ... 
The program took 0.03 seconds 


In [4]:
# Multivariate joint latent class model.
start = time()
robjects.r.source(os.getcwd() + "/competing_methods/MJLCMM.R")
trained_long_model, trained_mjlcmm = robjects.r["MJLCMM_fit"](data_R_train,
                                     robjects.StrVector(time_dep_feat),
                                     robjects.StrVector(time_indep_feat))
MJLCMM_pred = robjects.r["MJLCMM_score"](trained_long_model,
                                         trained_mjlcmm,
                                         time_indep_feat, data_R_test)
MJLCMM_marker = np.array(MJLCMM_pred.rx2('pprob')[2])
MJLCMM_c_index = c_index_score(T_test, MJLCMM_marker, delta_test)
MJLCMM_c_index = max(MJLCMM_c_index, 1 - MJLCMM_c_index)
MJLCMM_exe_time = time() - start

Be patient, multlcmm is running ... 
The program took 0.14 seconds 
The program took 1461.89 seconds 
The program took 0.19 seconds 


In [5]:
# Multivariate shared random effect model.
start = time()
robjects.r.source(os.getcwd() + "/competing_methods/JMBayes.R")
trained_JMBayes = robjects.r["fit"](data_R_train,
                                    robjects.StrVector(time_dep_feat),
                                    robjects.StrVector(time_indep_feat))
# JMBayes_pred = robjects.r["score"](trained_JMBayes, data_R_test, t_max=4)
# JMBayes_marker = np.array(JMBayes_pred.rx2('full.results')[0])
JMBayes_marker = np.array(robjects.r["score"](trained_JMBayes, data_R_test))
JMBayes_c_index = c_index_score(T_test, JMBayes_marker, delta_test)
JMBayes_c_index = max(JMBayes_c_index, 1 - JMBayes_c_index)
JMBayes_exe_time = time() - start

In [6]:
# lights
start = time()
fixed_effect_time_order = 1
fc_parameters = {
    "mean": None,
    "median": None,
    "quantile": [{"q": 0.25}, {"q": 0.75}]}
learner = prox_QNEM(fixed_effect_time_order=fixed_effect_time_order,
                      max_iter=5, initialize=True, print_every=1, l_pen_SGL=0.02, eta_sp_gp_l1=.9, l_pen_EN=0.02,
                     fc_parameters=fc_parameters)
learner.fit(X_lights_train, Y_lights_train, T_train, delta_train, Y_tsfresh_train)
lights_c_index = learner.score(X_lights_test, Y_lights_test, T_test, delta_test, Y_tsfresh_test)
lights_exe_time = time() - start

Launching the solver prox_QNEM...


Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 25.75it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 23.77it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 16.55it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 21.29it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 20.53it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 19.64it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 19.92it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 18.32it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 18.46it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 18.02it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 14.64it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 16.52it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 16.39it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 16.88it/s]
Feature Extraction: 100%|██████████| 20/20 [00:0

Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 20.72it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 18.32it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 20.92it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 20.64it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 20.06it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 19.21it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 19.09it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 19.69it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 17.37it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 19.45it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 19.16it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 19.15it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 19.05it/s]
Feature Extraction: 100%|██████████| 20/20 [00:01<00:00, 18.59it/s]
Feature Extraction: 100%|██████████| 20/20 [00:0

Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 21.05it/s]


Launching the solver MLMM...
Launching the solver ULMM...




Done solving using ULMM in 3.41e+00 seconds
 n_iter  |   obj    | rel_obj 
       0 |  3192.53 |      inf
       1 |  1168.38 | 6.34e-01
       2 |  862.032 | 2.62e-01
       3 |  688.533 | 2.01e-01
       4 |  580.748 | 1.57e-01
       5 |  510.801 | 1.20e-01
Done solving using MLMM in 4.49e+00 seconds
 n_iter  |   obj    | rel_obj 
       0 |  23.8921 |      inf
       1 |  12.7457 | 4.67e-01
       2 |  8.88471 | 3.03e-01
       3 |  6.73683 | 2.42e-01
       4 |  5.79546 | 1.40e-01
       5 |  5.25667 | 9.30e-02
Done solving using prox_QNEM in 3.02e+02 seconds


Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 76.77it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 104.12it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 92.98it/s] 
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 101.95it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 113.09it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 69.51it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 109.72it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 113.52it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 97.37it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 74.08it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 100.06it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 212.30it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 67.23it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 83.70it/s]
Feature Extraction: 100%|██████████| 20/

Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 68.17it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 71.07it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 103.09it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 84.05it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 89.35it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 73.61it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 90.82it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 71.85it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 85.48it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 122.47it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 85.24it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 77.98it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 74.24it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 136.31it/s]
Feature Extraction: 100%|██████████| 20/20 [0

Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 110.80it/s]
Feature Extraction: 100%|██████████| 20/20 [00:00<00:00, 79.16it/s]


In [7]:
t = PrettyTable(['Algos', 'C-Index', 'time'])
t.add_row(["Cox", "%g" % Cox_c_index, "%.3f" % Cox_exe_time])
t.add_row(["MJLCMM", "%g" % MJLCMM_c_index, "%.3f" % MJLCMM_exe_time])
t.add_row(["JMBayes", "%g" % JMBayes_c_index, "%.3f" % JMBayes_exe_time])
t.add_row(["lights", "%g" % lights_c_index, "%.3f" % lights_exe_time])
print(t)

+---------+----------+----------+
|  Algos  | C-Index  |   time   |
+---------+----------+----------+
|   Cox   | 0.770634 |  0.156   |
|  MJLCMM | 0.633362 | 1462.274 |
| JMBayes | 0.696785 | 862.912  |
|  lights | 0.662033 | 394.300  |
+---------+----------+----------+
