# Fit Cox Models

## Initialize

In [1]:
import os
import math
import pathlib
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from IPython.display import clear_output
import ray
import datetime
import subprocess
import warnings
import lifelines
from lifelines.utils import CensoringType
from lifelines.utils import concordance_index

In [2]:
# Base directory under which the project data is stored.
base_path = "/home/jakobs"
project_path = f"{base_path}/data"

# Experiment identifier; it names the experiment's sub-directory.
experiment = '230323'
experiment_path = f"{project_path}/{experiment}"

# Make sure the experiment directory (and any missing parents) exists.
pathlib.Path(experiment_path).mkdir(exist_ok=True, parents=True)

# Partition indices 0..9.
partitions = list(range(10))

In [3]:
os.environ['MKL_NUM_THREADS'] = "1"
os.environ['NUMEXPR_NUM_THREADS'] = "1"
os.environ['OMP_NUM_THREADS'] = "1"

In [4]:
from scripts.coxph_fit_partition import load_data, fit_endpoint

In [5]:
partitions

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

In [6]:
# Fit the Cox models of every partition, submitting one Ray task per endpoint.
for partition in tqdm(partitions):
    (
        eids_dict,
        score_defs,
        endpoint_defs,
        endpoints,
        models,
        model_path,
        experiment_path,
        data_partition,
    ) = load_data(partition)

    # Share the inputs with the tasks through Ray's object store.
    # An explicit ray.init(num_cpus=24) is disabled here: it crashed when
    # num_cpus > 16 (reason unknown).
    ray_eids = ray.put(eids_dict)
    ray_score_defs = ray.put(score_defs)
    ray_endpoint_defs = ray.put(endpoint_defs)
    ray_partition = ray.put(data_partition)

    # Submit all endpoint fits first, then wait for each of them in turn.
    progress = []
    for endpoint in endpoints:
        task = fit_endpoint.remote(
            ray_partition,
            ray_eids,
            ray_score_defs,
            ray_endpoint_defs,
            endpoint,
            partition,
            models,
            model_path,
            experiment_path,
        )
        progress.append(task)
    for s in tqdm(progress):
        ray.get(s)

  0%|          | 0/10 [00:00<?, ?it/s]

2023-03-23 15:29:44,875	INFO worker.py:1544 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m


  0%|          | 0/368 [00:00<?, ?it/s]

  0%|          | 0/368 [00:00<?, ?it/s]

  0%|          | 0/368 [00:00<?, ?it/s]

  0%|          | 0/368 [00:00<?, ?it/s]

  0%|          | 0/368 [00:00<?, ?it/s]

  0%|          | 0/368 [00:00<?, ?it/s]

  0%|          | 0/368 [00:00<?, ?it/s]

  0%|          | 0/368 [00:00<?, ?it/s]

  0%|          | 0/368 [00:00<?, ?it/s]

  0%|          | 0/368 [00:00<?, ?it/s]

## Check progress

In [7]:
# Collect the file names of all fitted models of this experiment.
cox_paths = os.listdir(f"{experiment_path}/coxph/models/")

# Drop the last two characters of every name (presumably the ".p" suffix)
# and put the remainder into a single "path" column.
path_df = pd.DataFrame(data=[name[:-2] for name in cox_paths])
path_df = path_df.rename(columns={0: "path"})
print(path_df.head())

# Split each name at "_" into its five components.
name_parts = path_df["path"].str.split("_", expand=True)
path_df[["endpoint_1", "endpoint_2", "score", "model", "partition"]] = name_parts

                                    path
0       phecode_168-1_SCORE2_RetinaUKB_9
1  phecode_582_SCORE2+Retina_RetinaUKB_6
2       phecode_714-3_SCORE2_RetinaUKB_0
3  phecode_477_SCORE2+Retina_RetinaUKB_0
4     phecode_431-12_Age+Sex_RetinaUKB_1


In [8]:
path_df["endpoint"] = path_df["endpoint_1"] + "_" + path_df["endpoint_2"] 

path_df.value_counts(["partition"]).to_frame()

Unnamed: 0_level_0,0
partition,Unnamed: 1_level_1
0,2576
1,2576
2,2576
3,2576
4,2576
5,2576
6,2576
7,2576
8,2576
9,2576


## Fix Crashing CoxPH models

In [None]:
in_path = pathlib.Path(f"{experiment_path}/coxph/input")
models = [f.name for f in in_path.iterdir() if f.is_dir() and "ipynb_checkpoints" not in str(f)]
models

In [None]:
import pandas as pd
# Endpoint metadata table; read once and reused for the endpoint list below.
endpoints_md = pd.read_csv('/sc-projects/sc-proj-ukb-cvd/results/projects/22_retinal_risk/data/220602/endpoints.csv')
# Sorted endpoint names with every '_prevalent' occurrence removed.
endpoints = sorted(name.replace('_prevalent', '') for name in endpoints_md.endpoint.values)
endpoints_md

In [None]:
import yaml
import pickle
import zstandard

def get_score_defs():
    """Load and return the score definitions from the project's YAML file."""
    definitions_file = r'/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas_220603_fullrun/data/score_definitions.yaml'
    with open(definitions_file) as handle:
        return yaml.full_load(handle)

def get_features(endpoint, score_defs):
    """Return the feature sets of every model for ``endpoint``.

    The result maps each name in the notebook-level ``models`` list to a
    dict of feature-set name -> covariate list. Only the QRISK3 based sets
    are active: "QRISK3" is ``score_defs["QRISK3"]`` and "QRISK3+Retina"
    is that list with ``endpoint`` appended. The Age+Sex, Retina, SCORE2
    and ASCVD variants are currently disabled.
    """
    features = {}
    for model in models:
        qrisk3 = score_defs["QRISK3"]
        features[model] = {
            "QRISK3": qrisk3,
            "QRISK3+Retina": qrisk3 + [endpoint],
        }
    return features

def load_pickle(fp):
    """Read the zstd-compressed pickle at ``fp`` and return the object."""
    with open(fp, "rb") as source, zstandard.ZstdDecompressor().stream_reader(source) as stream:
        return pickle.loads(stream.read())

def fit_cox(data_fit, feature_set, covariates, endpoint, penalizer, step_size=1):
    """Fit a ``CoxPHFitter`` on ``data_fit`` and return it.

    The duration and event columns are ``{endpoint}_time`` and
    ``{endpoint}_event``. ``feature_set`` and ``covariates`` are accepted
    but not used by this function.
    """
    model = CoxPHFitter(penalizer=penalizer)
    duration_col = f"{endpoint}_time"
    event_col = f"{endpoint}_event"
    model.fit(data_fit, duration_col, event_col, step_size=step_size)
    return model

score_defs = get_score_defs()

In [None]:
def clean_covariates(endpoint, covariates):
    """Remove the covariates that must not be used for ``endpoint``.

    Some endpoints have covariates that describe (nearly) the same
    condition as the endpoint itself, e.g. the diabetes indicators for the
    diabetes endpoints; those are dropped before fitting.

    Returns a new, filtered list when a rule exists for ``endpoint``;
    otherwise ``covariates`` is returned unchanged (the same object).
    """
    lupus = ("systemic_lupus_erythematosus",)
    schizophrenia = ("schizophrenia",)
    migraine = ("migraine",)
    atrial_fibrillation = ("atrial_fibrillation",)
    rheumatoid_arthritis = ("rheumatoid_arthritis",)
    male_covariates = ("sex_Male", "male_erectile_dysfunction")

    # Each rule: (endpoints it applies to, covariates to drop).
    rules = (
        # Autoimmune disease, diffuse diseases of connective tissue, lupus, SLE
        (("phecode_181", "phecode_700", "phecode_700-1", "phecode_700-11"), lupus),
        # Diabetes
        (("phecode_202",), ("diabetes1", "diabetes2", "diabetes")),
        # Diabetes 1
        (("phecode_202-1",), ("diabetes1",)),
        # Diabetes 2
        (("phecode_202-2",), ("diabetes2",)),
        # Mood [affective] disorders
        (("phecode_286",), ("bipolar_disorder", "major_depressive_disorder")),
        # Bipolar disorder
        (("phecode_286-1",), ("bipolar_disorder",)),
        # Major depressive disorder
        (("phecode_286-2",), ("major_depressive_disorder",)),
        # Psychotic disorders, schizophrenia
        (("phecode_287", "phecode_287-1"), schizophrenia),
        # Headache
        (("phecode_331", "phecode_331-6"), migraine),
        # Atrial fibrillation (and flutter)
        (("phecode_416", "phecode_416-2", "phecode_416-21"), atrial_fibrillation),
        # Renal failure
        (("phecode_584",), ("renal_failure",)),
        # Male sexual dysfunction
        (("phecode_605", "phecode_605-1"), male_covariates),
        # Rheumatoid arthritis and other inflammatory
        (("phecode_705", "phecode_705-1"), rheumatoid_arthritis),
        # added by lukas
        (("phecode_620", "phecode_627", "phecode_627-4"), male_covariates),
    )

    for codes, dropped in rules:
        if endpoint in codes:
            return [c for c in covariates if c not in dropped]
    return covariates

In [None]:
# problematic endpoints
problem_endpoints = [
    'phecode_620',
    'phecode_627',
    'phecode_627-4',
    #"phecode_181",
    #"phecode_202",
    #"phecode_202-1",
    #"phecode_286",
    #"phecode_287-1",
    #"phecode_331",
    #"phecode_416",
    #"phecode_416-2",
    #"phecode_416-21",
    #"phecode_584", # 
    #"phecode_605",
    #"phecode_700",
    #"phecode_700-1",
    #"phecode_700-11",
    #"phecode_705",
    #"phecode_705-1"
]

endpoints_md.query("endpoint==@problem_endpoints")

In [None]:
from lifelines import CoxPHFitter
# Re-fit one model whose fit failed, using the data that was saved for it.
# Depends on notebook state from earlier cells: problem_endpoints,
# endpoints_md, experiment_path, score_defs and the helper functions.
endpoint = problem_endpoints[0]
# NOTE(review): 20 is outside range(10) used at the top of the notebook but
# inside range(22) used in the submitit cell — confirm which run this refers to.
partition = 20
feature_set = "QRISK3" # QRISK3+Retina
# `display` is provided by the IPython/Jupyter environment.
display(endpoints_md.query("endpoint==@endpoint"))

# Load the saved "errordata" frame for this endpoint / feature set / partition.
test_data = load_pickle(f"{experiment_path}/coxph/errordata_{endpoint}_{feature_set}_{partition}.p")
display(test_data.T)

features = get_features(endpoint, score_defs)
print(features)
# The feature sets are looked up under one hard-coded model name.
covariates = features["ImageTraining_[]_ConvNeXt_MLPHead_predictions_cropratio0.3"][feature_set]

# # clean covariates for the coxphs to fit
covariates = clean_covariates(endpoint, covariates)
print(covariates)

# Keep only the covariates plus the event/time columns, cast to float32.
data_endpoint = test_data[covariates + [f"{endpoint}_event", f"{endpoint}_time"]].astype(np.float32)

# Fit without penalization and with a reduced step size of 0.1.
cph = fit_cox(data_endpoint,#.drop(columns=["systemic_lupus_erythematosus"]), 
              feature_set, covariates, endpoint, penalizer=0, step_size=0.1)
cph.print_summary()

In [None]:
jobids

In [None]:
#fit_partition(in_path, model_path, score_defs, 0)

In [None]:
import submitit

# Partition indices 0..21; this rebinds the `partitions` defined at the top
# of the notebook.
partitions = [i for i in range(22)]

# Executor that writes its logs below "log_test/%j".
# NOTE(review): presumably %j is expanded to the job id by submitit — confirm.
executor = submitit.AutoExecutor(folder="log_test/%j")
# set timeout in min, and partition for running the job
executor.update_parameters(slurm_array_parallelism=6,
                           nodes=1,
                           #slurm_mem="500G",
                           timeout_min=600,
                          #slurm_setup=[
                          #  """export MKL_NUM_THREADS=1""",
                          #  """export NUMEXPR_NUM_THREADS=1""",
                           # """export OMP_NUM_THREADS=1"""]
                          )

# Submit fit_partition for the partitions via map_array.
# NOTE(review): fit_partition is not defined in the visible part of this
# notebook — confirm where it comes from before running this cell.
job = executor.map_array(fit_partition, partitions)  


# jobs = []
# for partition in tqdm(partitions):
#     job = executor.submit(fit_partition, in_path, model_path, score_defs, partition)
#     jobs.append(job)

In [None]:
%env MKL_NUM_THREADS=1
%env NUMEXPR_NUM_THREADS=1
%env OMP_NUM_THREADS=1

import ray
#ray start --head --port=6379 --num-cpus 64 # in terminal
#ray.init(address='auto')#, dashboard_port=24762, dashboard_host="0.0.0.0", include_dashboard=True)#, webui_url="0.0.0.0"))
#ray.init(num_cpus=32)#, dashboard_port=24762, dashboard_host="0.0.0.0", include_dashboard=True)#, webui_url="0.0.0.0"))

In [None]:
ray.available_resources()

In [None]:
# NOTE(review): the loop variable `partition` is never passed on, so every
# iteration makes the identical call fit_partition() — confirm whether
# fit_partition(partition) was intended. fit_partition is not defined in the
# visible part of this notebook.
for partition in tqdm(partitions):
    fit_partition()

In [None]:
## Debugging

In [None]:
# fit_endpoint(data_partition, eids_dict, endpoint_defs, endpoints[0], partition, models, model_path)

In [None]:
cph_1 = load_pickle("/sc-projects/sc-proj-ukb-cvd/results/projects/22_medical_records/data/220613/coxph/models/OMOP_4306655_QRISK3_0.p")
cph_2 = load_pickle("/sc-projects/sc-proj-ukb-cvd/results/projects/22_medical_records/data/220613/coxph/models/OMOP_4306655_Age+Sex+MedicalHistory_0.p")

In [None]:
cph_1.print_summary()

In [None]:
cph_2.print_summary()

In [None]:
# Columns to one-hot encode. The original cell was left unfinished
# (`encode_cols = ` and `prefix=` had no values) and did not parse.
encode_cols = ["ethnic_background", "sex", "smoking_status"]
# `prefix` is left at its default, so the dummy columns are prefixed with the
# name of the column they were derived from.
pd.get_dummies(data_partition['Identity(Records)+MLP'], columns=encode_cols)

In [None]:
data_partition['Identity(Records)+MLP'].select_dtypes("category")

In [None]:
load_pickle("/sc-projects/sc-proj-ukb-cvd/results/projects/22_medical_records/data/220413/coxph/errordata_phecode_002-1_Age+Sex+MedicalHistory+I(Age*MH)_0.p")

In [None]:
data_partition['Identity(Records)+MLP']['phecode_977']

# old stuff

In [None]:
!ls -al {output_path}

In [None]:
print(output_path)
data_outcomes = pd.read_feather(f"{output_path}/baseline_outcomes_220531.feather").set_index("eid")
data_outcomes

In [None]:
import pandas as pd
# Sorted endpoint names from the endpoint table, with '_prevalent' removed.
all_endpoints = sorted(
    name.replace('_prevalent', '')
    for name in pd.read_csv('/sc-projects/sc-proj-ukb-cvd/results/projects/22_retinal_risk/data/220602/endpoints.csv').endpoint.values
)
print(len(all_endpoints))

# Endpoints to leave out. The list is empty at the moment (loading it from
# "endpoints_not_overlapping.csv" in the experiment directory is disabled),
# so every endpoint is kept.
endpoints_not_overlapping_with_preds = []

# Keep the endpoints that are not excluded; including the excluded ones is
# what caused errors.
endpoints = [c for c in all_endpoints if c not in endpoints_not_overlapping_with_preds]
print(len(endpoints))

In [None]:
splits = ["train", "valid", 'test'] # "test_left", 'test_right'

In [None]:
endpoint_defs = pd.read_feather(f"{output_path}/phecode_defs_220306.feather").query("endpoint==@endpoints").sort_values("endpoint").set_index("endpoint")

In [None]:
from datetime import date
today = str(date.today())

In [None]:
eligable_eids = pd.read_feather(f"{output_path}/eligable_eids_{today}.feather") # TODO CHANGE!
eids_dict = eligable_eids.set_index("endpoint")["eid_list"].to_dict()

In [None]:
%env MKL_NUM_THREADS=4
%env NUMEXPR_NUM_THREADS=4
%env OMP_NUM_THREADS=4

In [None]:
ray.shutdown()

In [None]:
import ray

ray.init(num_cpus=24)#, dashboard_port=24762, dashboard_host="0.0.0.0", include_dashboard=True)#, webui_url="0.0.0.0"))

In [None]:
AgeSex = ["age_at_recruitment_f21022_0_0", "sex_f31_0_0"]

# Train COX

In [None]:
in_path = pathlib.Path(f"{experiment_path}/coxph/input")
in_path.mkdir(parents=True, exist_ok=True)

model_path = f"{experiment_path}/coxph/models"
pathlib.Path(model_path).mkdir(parents=True, exist_ok=True)

In [None]:
models = [f.name for f in in_path.iterdir() if f.is_dir() and "ipynb_checkpoints" not in str(f)]
models

In [None]:
from formulaic.errors import FactorEvaluationError

In [None]:
in_path

In [None]:
from lifelines import CoxPHFitter
from lifelines.exceptions import ConvergenceError
import zstandard
import pickle

def get_features(endpoint):
    """Return the feature sets for ``endpoint``, keyed by the first model.

    Reads the notebook-level ``models`` and ``AgeSex``. Only ``models[0]``
    gets an entry (marked "TODO CHANGE!" by the author). Its feature sets
    are "Age+Sex", "Retina" (the endpoint column itself) and
    "Age+Sex+Retina"; the "Age+Sex+MedicalHistory+I(Age*MH)" set is
    disabled.
    """
    first_model = models[0]  # TODO CHANGE!
    feature_sets = {
        "Age+Sex": AgeSex,
        "Retina": [endpoint],
        "Age+Sex+Retina": AgeSex + [endpoint],
    }
    return {first_model: feature_sets}

def get_train_data(in_path, partition, models, mapping):
    """Load the training frame of every model for one partition.

    For each model, ``{in_path}/{model}/{partition}/train.feather`` is read
    and indexed by "eid", left-joined on the index with the notebook-level
    ``data_outcomes`` frame, and passed through ``DataFrame.replace(mapping)``.

    Returns a dict of model name -> resulting DataFrame.
    """
    train_data = {}
    for model in models:
        frame = pd.read_feather(f"{in_path}/{model}/{partition}/train.feather")
        frame = frame.set_index("eid")
        frame = frame.merge(data_outcomes, left_index=True, right_index=True, how="left")
        train_data[model] = frame.replace(mapping)
    return train_data

def fit_cox(data_fit, feature_set, covariates, endpoint, penalizer, step_size=1):
    """Fit a ``CoxPHFitter`` on ``data_fit`` and return it.

    For every feature set except "Age+Sex+MedicalHistory+I(Age*MH)" the
    model is fit on all columns of ``data_fit`` with ``{endpoint}_time`` as
    duration and ``{endpoint}_event`` as event column.

    For "Age+Sex+MedicalHistory+I(Age*MH)" the model is fit with a formula
    that interacts age (and sex, if it is among ``covariates``) with the
    endpoint column. All "-" characters are removed from the endpoint,
    column and covariate names first (presumably because "-" is an operator
    in the formula syntax — confirm).

    ``data_fit`` is never modified: the renaming is done on a copy, so the
    caller's frame keeps its original column names.
    """
    cph = CoxPHFitter(penalizer=penalizer)
    if feature_set != "Age+Sex+MedicalHistory+I(Age*MH)":
        cph.fit(data_fit, f"{endpoint}_time", f"{endpoint}_event", step_size=step_size)
        return cph

    endpoint_label = endpoint.replace("-", "")
    # rename() returns a new frame; assigning to data_fit.columns would have
    # renamed the columns of the caller's DataFrame as a side effect.
    data_fit = data_fit.rename(columns=lambda c: c.replace("-", ""))
    covariates = [c.replace("-", "") for c in covariates]

    formula = f"age_at_recruitment_f21022_0_0*{endpoint_label}"
    if "sex_f31_0_0" in covariates:
        formula += f"+sex_f31_0_0*{endpoint_label}"
    cph.fit(data_fit, f"{endpoint_label}_time", f"{endpoint_label}_event", formula=formula, step_size=step_size)
    return cph

def save_pickle(data, data_path):
    """Pickle ``data`` and write it zstd-compressed to ``data_path``."""
    with open(data_path, "wb") as sink, zstandard.ZstdCompressor().stream_writer(sink) as writer:
        serialized = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
        writer.write(serialized)
            
def load_pickle(fp):
    """Return the object stored in the zstd-compressed pickle file ``fp``."""
    with open(fp, "rb") as source:
        reader = zstandard.ZstdDecompressor().stream_reader(source)
        with reader as stream:
            payload = stream.read()
            return pickle.loads(payload)

@ray.remote
def fit_endpoint(data_partition, eids_dict, endpoint_defs, endpoint, partition, models, model_path):
    """Fit and store the Cox models of one endpoint for one partition.

    For every model in ``models`` and every feature set returned by
    ``get_features(endpoint)``, a fitted model is written to
    ``{model_path}/{endpoint}_{feature_set}_{partition}.p`` unless a file
    with that name already exists and can be loaded.

    Fitting is attempted with step sizes 1, 0.5 and 0.1 in turn. If all
    attempts fail, the input data is saved to
    ``{experiment_path}/coxph/errordata_{endpoint}_{feature_set}_{partition}.p``
    (``experiment_path`` is read from the notebook scope, not passed in).

    Args:
        data_partition: mapping of model name -> DataFrame for this partition.
        eids_dict: mapping of endpoint -> ids to include; ``.tolist()`` is
            called on the entry for ``endpoint``.
        endpoint_defs: table indexed by endpoint with a "sex" column.
        endpoint: name of the endpoint to fit.
        partition: partition identifier used in the output file names.
        models: model names to fit.
        model_path: directory the fitted models are written to.

    Returns:
        True, always.
    """
    fit_errors = (ValueError, ConvergenceError, KeyError, FactorEvaluationError)
    eids_incl = eids_dict[endpoint].tolist()
    features = get_features(endpoint)
    eligibility = endpoint_defs.loc[endpoint]["sex"]
    for model in models:
        data_model = data_partition[model]
        for feature_set, covariates in features[model].items():
            cph_path = f"{model_path}/{endpoint}_{feature_set}_{partition}.p"

            # Skip feature sets whose stored model is present and loadable.
            if os.path.isfile(cph_path):
                try:
                    load_pickle(cph_path)
                except Exception:
                    # Unreadable file (e.g. truncated write): refit below.
                    pass
                else:
                    continue

            # Sex cannot be a covariate when the endpoint is not defined for
            # both sexes.
            if (eligibility != "Both") and ("sex_f31_0_0" in covariates):
                covariates = [c for c in covariates if c != "sex_f31_0_0"]
            data_endpoint = data_model[covariates + [f"{endpoint}_event", f"{endpoint}_time"]].astype(np.float32)
            data_endpoint = data_endpoint[data_endpoint.index.isin(eids_incl)]

            # Retry with progressively smaller step sizes.
            for step_size in (1, 0.5, 0.1):
                try:
                    cph = fit_cox(data_endpoint, feature_set, covariates, endpoint, penalizer=0.0, step_size=step_size)
                    save_pickle(cph, cph_path)
                except fit_errors:
                    if step_size == 1:
                        print("ConvergenceError", model, endpoint, feature_set, partition, "problem: reduce step size")
                    else:
                        print("ConvergenceError", model, endpoint, feature_set, partition, f"trying with reduced step size ... {step_size} failed")
                    continue
                if step_size != 1:
                    print("ConvergenceError", model, endpoint, feature_set, partition, f"trying with reduced step size ... {step_size} successfull")
                break
            else:
                # Every attempt failed: keep the input data for debugging.
                save_pickle(data_endpoint, f"{experiment_path}/coxph/errordata_{endpoint}_{feature_set}_{partition}.p")
    return True

In [None]:
f"{experiment_path}/coxph"

In [None]:
model_list =  !ls $model_path
#model_list = [m for m in model_list if "I(" in m]
model_list = [m for m in model_list]

In [None]:
model_list

In [None]:
1+1

In [None]:
# Recode the sex column from labels to 0/1.
mapping = {"sex_f31_0_0": {"Female":0, "Male":1}}

ray_eids = ray.put(eids_dict)
ray_endpoint_defs = ray.put(endpoint_defs)
for partition in tqdm([0]): # in tqdm(partitions) # TODO: CHANGE!
    # Drop the reference to the previous partition's data before loading
    # the next one; on the first pass the name does not exist yet.
    try:
        del ray_partition
    except NameError:
        print("Ray object not yet initialised")
    try:
        data_partition = get_train_data(in_path, partition, models, mapping)
        ray_partition = ray.put(data_partition)
        # Submit one task per endpoint, then wait for all of them.
        progress = []
        for endpoint in endpoints:
            progress.append(fit_endpoint.remote(ray_partition, ray_eids, ray_endpoint_defs, endpoint, partition, models, model_path))
        for s in tqdm(progress):
            ray.get(s)
    except FileNotFoundError:
        print('file not found')

In [None]:
load_pickle("/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/test_experiment/coxph/models/phecode_841_Retina_0.p")

In [None]:
data_partition['Identity(Records)+MLP']['phecode_977']