# Benchmarks

## Initialize

In [None]:
%load_ext autoreload
%autoreload 2

import os
from tqdm.auto import tqdm
import pathlib

import numpy as np
import pandas as pd
import lifelines
import pandas as pd

from joblib import Parallel, delayed
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings("ignore")
import shutil

import plotly.express as px
import plotly.graph_objects as go
from plotly.graph_objects import Box

import matplotlib.pyplot as plt
from lifelines import CRCSplineFitter
import warnings
from lifelines.utils import CensoringType

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import math

from IPython.display import clear_output
import pathlib

from lifelines.utils import concordance_index

In [None]:
from dask.distributed import Client, LocalCluster
# Local Dask cluster: a single worker process running 10 threads, with a client
# attached to it.
# NOTE(review): none of the visible cells submit work through `client` — confirm
# it is still needed.
cluster = LocalCluster(threads_per_worker=10, n_workers=1)
client = Client(cluster)

In [None]:
# Show which machine this notebook is running on (the next cell selects the
# storage root from the hostname).
!hostname

In [None]:
# Choose the storage root from the hostname: hosts whose name contains "sc" use
# the /sc-projects mount, all others the ag-reils shared drive.
# NOTE(review): this is a plain substring test — any hostname containing "sc"
# matches; confirm that cannot misfire on other machines.
node = !hostname
if "sc" in node[0]:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else: base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

# Dataset directories for this project (pre-/post-processed, judging by the
# directory names).
project_name = "210714_metabolomics"
#data_path = "/data/analysis/ag-reils/steinfej"
data_pre = f"{base_path}/data/2_datasets_pre/{project_name}"
data_post = f"{base_path}/data/3_datasets_post/{project_name}"

# Result locations for this project; the figures and data directories are
# created if they do not exist yet.
project_label = "21_metabolomics_multitask"
project_path = f"{base_path}/results/projects/{project_label}"
figures_path = f"{project_path}/figures"
data_results_path = f"{project_path}/data"
pathlib.Path(figures_path).mkdir(parents=True, exist_ok=True)
pathlib.Path(data_results_path).mkdir(parents=True, exist_ok=True)

In [None]:
# Endpoints analysed, grouped by clinical area. Each name is used verbatim to
# build model file names (DS_<endpoint>_...) and log-hazard column names
# (logh_<endpoint>_Metabolomics), so the spelling must match the stored data.
endpoints = [
    # Cardiovascular
    'M_MACE',
    'M_coronary_heart_disease',
    'M_cerebral_stroke',
    'M_peripheral_arterial_disease',
    'M_atrial_fibrillation',
    'M_heart_failure',
    'M_abdominal_aortic_aneurysm',
    'M_venous_thrombosis',

    # General internal medicine
    'M_type_2_diabetes',
    'M_liver_disease',
    'M_renal_disease',

    # Pulmonary
    'M_asthma',
    # NOTE(review): "obstructuve" is misspelled, but this string is a lookup key —
    # presumably it matches the upstream data; do not correct it here alone.
    'M_chronic_obstructuve_pulmonary_disease',

    # Psychiatric/Neurological
    'M_all_cause_dementia',
    'M_parkinsons_disease',

    # Cancers
    "M_lung_cancer",
    "M_non_melanoma_skin_cancer",
    "M_colon_cancer",
    "M_rectal_cancer",
    "M_prostate_cancer",
    "M_breast_cancer",

    # Ophthalmological
    'M_cataracts',
    'M_glaucoma',

    # Traumatology
    'M_fractures',
]

In [None]:
# Partition identifiers "0" .. "21" (as strings) and the dataset split names.
partitions = list(map(str, range(22)))
splits = "train valid test".split()

In [None]:
# List what is already stored in the results data directory.
!ls {data_results_path}

### Load models

In [None]:
import joblib
def get_cph(path):
    """Load and return a pickled fitted model stored at ``path``.

    ``pickle`` is imported locally so the function no longer depends on a later
    notebook cell having imported it before the first call.

    Security: ``pickle.load`` can execute arbitrary code while unpickling — only
    use this on model files produced by this project, never on untrusted input.
    """
    import pickle

    with open(path, "rb") as f:
        return pickle.load(f)

In [None]:
# Identifier of the model run; it selects the log-hazard file and the COX_<run>
# model directory used below.
run = "220126"

In [None]:
# Log-hazards written by model run `run`, restricted to the test split, with all
# columns whose name contains "AgeSex" removed.
loghs = pd.read_feather(f"{data_results_path}/loghazards_model_{run}_metabolomics.feather").query("split=='test'")
cols = [col for col in loghs.columns if "AgeSex" not in col]
loghs = loghs[cols]

In [None]:
# Directory holding this run's pickled Cox models; created if it is missing.
version = f"COX_{run}"
dump_path = "{}/{}".format(data_post, version)
pathlib.Path(dump_path).mkdir(exist_ok=True, parents=True)

In [None]:
# Covariate sets; each model file combines one of them with metabolomics
# ("<features>+Metabolomics" in the file name).
feature_sets = [
    "Age+Sex",
    "ASCVDnoblood",
    "ASCVD",
    "PANELnoblood",
    "PANELjustbloodcount",
    "PANEL",
]

In [None]:
# One pickled model per (feature set, partition, endpoint): feature set varies
# slowest, endpoint fastest.
paths = [
    f"{dump_path}/DS_{ep}_{fs}+Metabolomics_{part}.p"
    for fs in feature_sets
    for part in partitions
    for ep in endpoints
]

In [None]:
import glob, os
import glob
import pickle
import re
# Load every model, keyed by its file name without the ".p" extension,
# e.g. "DS_<endpoint>_<features>+Metabolomics_<partition>".
cph_dict = {}
for path in tqdm(paths):
    key = pathlib.Path(path).stem
    cph_dict[key] = get_cph(path)

In [None]:
from lifelines import utils
from lifelines.fitters import RegressionFitter, SemiParametricRegressionFitter
from lifelines.plotting import set_kwargs_drawstyle
def plot_partial_effects_on_outcome(cph, covariates, values, plot_baseline=True, y="survival_function", **kwargs):
    """Predict outcome curves for a fitted model with selected covariates fixed.

    Despite the name, nothing is plotted. The function builds one row per entry
    of ``values``. In each row, every covariate takes ``cph._central_values``
    except those in ``covariates``, which take the given values. It then returns
    ``cph.predict_<y>(rows)``. It appears to be derived from lifelines'
    ``plot_partial_effects_on_outcome`` (TODO confirm), hence the signature.

    Parameters
    ----------
    cph
        Fitted model exposing ``_central_values`` (a one-row DataFrame of
        covariate values), ``strata`` and a ``predict_<y>`` method.
    covariates : str or list of str
        Covariate name(s) to vary.
    values : array-like
        1-D (one curve per value) when a single covariate is varied, otherwise
        shape ``(n_curves, n_covariates)``.
    plot_baseline, **kwargs
        Accepted for call compatibility (callers pass e.g. ``cmap``); ignored.
    y : str
        Suffix of the prediction method to call, e.g. ``"survival_function"``.

    Returns
    -------
    Whatever ``cph.predict_<y>`` returns for the constructed rows. Rows are
    labelled ``"<cov>=1"`` when ``values`` is an identity matrix, otherwise
    ``"<cov>=<value>, ..."``.

    Raises
    ------
    ValueError
        If the number of covariates differs from ``values.shape[1]``.
    KeyError
        If a covariate is not one of the model's covariates.
    NotImplementedError
        For stratified models (``cph.strata`` is not None). Previously this path
        fell through to an ``UnboundLocalError``.
    """
    covariates = covariates if isinstance(covariates, list) else [covariates]
    n_covariates = len(covariates)
    values = np.asarray(values)
    if values.ndim == 1:
        # a flat sequence means one curve per value of a single covariate
        values = values.reshape(-1, 1)

    if n_covariates != values.shape[1]:
        raise ValueError("The number of covariates must equal to second dimension of the values array.")

    for covariate in covariates:
        if covariate not in cph._central_values.columns:
            raise KeyError("covariate `%s` is not present in the original dataset" % covariate)

    if cph.strata is not None:
        raise NotImplementedError("stratified models are not supported by this helper")

    x_bar = cph._central_values
    X = pd.concat([x_bar] * values.shape[0])

    # Identity design -> short "<cov>=1" labels. (An identity matrix with an extra
    # zero column was also checked before, but that shape can never pass the
    # shape validation above, so the branch was unreachable.)
    if np.array_equal(np.eye(n_covariates), values):
        X.index = [f"{c}=1" for c in covariates]
    else:
        X.index = [", ".join(f"{c}={v}" for c, v in zip(covariates, row)) for row in values]

    for covariate, column in zip(covariates, values.T):
        X[covariate] = column

    # restore the model's original column dtypes after the assignments above
    X = X.astype(x_bar.dtypes)

    return getattr(cph, f"predict_{y}")(X)

In [None]:
def get_part_effects_df(endpoint="M_MACE", features="Age+Sex", quantiles=(0.01, 0.1, 0.5, 0.9, 0.99), partition=0):
    """Predicted outcome curves for one model at quantiles of its metabolomics log-hazard.

    Looks up model ``DS_<endpoint>_<features>+Metabolomics_<partition>`` in the
    global ``cph_dict``. It computes the requested quantiles of column
    ``logh_<endpoint>_Metabolomics`` in the global ``loghs`` and returns the
    model's predictions with that covariate set to each quantile value.

    ``quantiles`` may be any sequence. The default is a tuple rather than a
    list so a shared mutable default cannot be modified between calls.
    """
    cph = cph_dict[f"DS_{endpoint}_{features}+Metabolomics_{partition}"]
    covariate = f"logh_{endpoint}_Metabolomics"
    values = loghs[covariate].quantile(q=list(quantiles)).to_list()
    # `cmap` is ignored by the helper; kept for parity with the plotting API
    return plot_partial_effects_on_outcome(cph, covariates=covariate, values=values, cmap="coolwarm")

def clean_df(df, quantiles):
    """Relabel the columns of ``df`` in place and return the same frame.

    ``df`` is expected to hold one prediction column per quantile (in order),
    followed by the ``endpoint``, ``features`` and ``partition`` columns. The
    quantile columns are renamed to the quantile values themselves.
    ``quantiles`` may be any sequence; the previous ``quantiles + [...]`` form
    only worked when a list was passed. pandas rejects the assignment if the
    number of labels does not match the number of columns.
    """
    df.columns = [*quantiles, "endpoint", "features", "partition"]
    return df
    
quantiles = [0.01, 0.1, 0.5, 0.9, 0.99]

# Predicted curves for every (endpoint, feature set, partition) model at each
# log-hazard quantile, stacked into one frame with the former index as "time".
partial_effects = (
    pd.concat(
        [
            clean_df(
                get_part_effects_df(ep, fs, quantiles, part).assign(
                    endpoint=ep, features=fs, partition=part
                ),
                quantiles,
            )
            for ep in tqdm(endpoints)
            for fs in feature_sets
            for part in partitions
        ]
    )
    .reset_index()
    .rename(columns={"index": "time"})
)

In [None]:
import plotly.express as px

# Long format: one row per (time, endpoint, features, partition, quantile);
# `surv` holds the predicted value and `Ft` is 1 - surv.
df = (
    partial_effects
    .melt(
        id_vars=["time", "endpoint", "features", "partition"],
        value_vars=quantiles,
        var_name="quantile",
        value_name="surv",
    )
    .assign(Ft=lambda frame: 1 - frame.surv)
)

In [None]:
# Persist the long-format partial-effects table for this run.
df.to_feather(f"{data_results_path}/adj_partial_effects_metabolomics_{run}.feather")

In [None]:
# Collect each model's coefficient summary, tagged with endpoint, model key and
# partition, into a list of frames.
hrs = []
for key, cph in tqdm(cph_dict.items()):
    # every key built from `paths` contains both markers; skip anything else
    if "DS" not in key or "Metabolomics" not in key:
        continue
    # NOTE(review): assumes `event_col` is "<endpoint>" plus a 6-character suffix
    # (e.g. "_event") — confirm against the training code.
    endpoint = cph.event_col[:-6]
    # keys end in "_<partition>"; parse the trailing number whatever its width
    # (the old key[-2:] slice only handled partitions below 100)
    partition = int(key.rsplit("_", 1)[-1])
    # NOTE: `features` receives the full model key, not just the feature-set name
    hrs.append(cph.summary.assign(module="DS", endpoint=endpoint, features=key, partition=partition).reset_index())

In [None]:
# Stack the per-model summaries into one frame with a fresh 0..n-1 index.
hrs_df = pd.concat(hrs, axis=0, ignore_index=True)

In [None]:
# Persist the stacked per-model coefficient summaries (`hrs_df`) for this run.
hrs_df.to_feather(f"{data_results_path}/hrs_metabolomics_{run}.feather")