# Fit Cox Models

In [None]:
import os
import math
import pathlib
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from IPython.display import clear_output
import ray

import warnings
import lifelines
from lifelines.utils import CensoringType
from lifelines.utils import concordance_index

## Submit CoxPH jobs

In [None]:
# %%
import datetime
import itertools
import os
import pathlib
import re
import subprocess

import pandas as pd
from omegaconf import OmegaConf
from tqdm.auto import tqdm

# %% codecell
# Settings shared by every job-submission cell below.
USER = "USER"  # anonymized account name
BASE = pathlib.Path("/home", USER, "code")  # root of the code checkout

# Name under which the generated .sh scripts and yamls are stored.
EXPERIMENT_NAME = "22_medhistory"
# Template yaml location (original note: "template yaml to use").
TEMPLATE_CONFIG = f"{BASE}/config/"
# Script each SLURM job executes; kept as a plain string for interpolation.
TRAIN_SCRIPT = str(
    BASE
    / "MedicalHistoryPhenomeWide"
    / "2_downstream_processing"
    / "08_coxph_fit_partition.py"
)
# Shell command placed in the job script before anything else is run.
ACTIVATE_ENV_CMD = "mamba activate udm"

TAG = 230425  # NOTE(review): not referenced by any visible cell - confirm it is still needed
JOBNAME = "fit_coxph"  # prefix of every submitted job's name

In [None]:
# Echo the resolved training-script path as the cell output.
TRAIN_SCRIPT

In [None]:
# Create the scratch directories for generated scripts, configs and SLURM logs.
# exist_ok=True makes the cell safe to re-run.
for job_subdir in ("job_submissions", "job_configs", "job_outputs"):
    os.makedirs(f"/home/{USER}/tmp/{EXPERIMENT_NAME}/{job_subdir}", exist_ok=True)

In [None]:
def make_job_script(user, job_name, partition):
    """Render the SLURM batch script that runs the training script for one partition.

    Args:
        user: Not used by the template; kept so existing calls keep working.
        job_name: Value written into the ``#SBATCH --job-name`` directive.
        partition: Appended as the single command-line argument of the
            training script.

    Returns:
        The complete script as one string (no trailing newline).
    """
    # ACTIVATE_ENV_CMD and TRAIN_SCRIPT are module-level settings and are
    # interpolated verbatim.
    return f"""#!/bin/bash

#SBATCH --job-name={job_name}  # Specify job name
#SBATCH --nodes=1              # Specify number of nodes
#SBATCH --mem=485G              # Specify number of nodes
#SBATCH --time=4:00:00        # Set a limit on the total run time
#SBATCH --tasks-per-node=1
#SBATCH --exclusive

{ACTIVATE_ENV_CMD}

ray start --head --port=6378 --num-cpus 24
python {TRAIN_SCRIPT} {partition}"""



In [None]:
def submit(path, job_name, job_script, time_stamp=None):
    """Write a job script to disk and hand it to ``sbatch``.

    The script is saved as ``{path}/{job_name}_{time_stamp}.sh`` and
    ``{path}/{job_name}.sh`` is (re)pointed at it as a symlink, so the latest
    version is always reachable under a stable name. stdout/stderr of the job
    are directed to ``/home/{USER}/tmp/{EXPERIMENT_NAME}/job_outputs/``.

    Args:
        path: Existing directory that receives the script and the symlink.
        job_name: Base name for the script file and the log files.
        job_script: Full text of the batch script.
        time_stamp: Optional suffix for the script file name; defaults to the
            current local time formatted as ``%Y-%m-%d_%H:%M:%S``.

    Returns:
        The SLURM job id as a string when sbatch reports
        ``Submitted batch job <id>``; ``None`` when sbatch fails or its output
        cannot be parsed. (Previously nothing was returned, so callers that
        collected the result only ever stored ``None``.)
    """
    if not time_stamp:
        time_stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

    script_path_long = f"{path}/{job_name}_{time_stamp}.sh"

    with open(script_path_long, "w") as outfile:
        outfile.write(job_script)
    script_path = f"{path}/{job_name}.sh"
    try:
        os.unlink(script_path)
    except FileNotFoundError:  # because we cannot overwrite symlinks directly
        pass
    os.symlink(os.path.realpath(script_path_long), script_path)

    output_path = f"/home/{USER}/tmp/{EXPERIMENT_NAME}/job_outputs/{job_name}"

    print(job_script)
    print("\n\nSubmission:\n===========\n")
    # %j is expanded by SLURM to the job id in the log file names.
    sub_cmd = (
        f"sbatch --error={output_path}_%j_stderr.out --output={output_path}_%j_stdout.out <"
        f" {script_path}"
    )
    print(sub_cmd)

    # shell=True is required for the "<" redirection; every interpolated value
    # comes from this notebook, not from external input.
    ret = subprocess.run(sub_cmd, shell=True, cwd=os.getcwd(), capture_output=True)
    stdout_text = ret.stdout.decode()
    print(stdout_text)

    if ret.returncode != 0:
        # Surface the reason instead of failing silently.
        print(ret.stderr.decode())
        return None

    submitted = re.search(r"Submitted batch job (\d+)", stdout_text)
    return submitted.group(1) if submitted else None

In [None]:
# Partitions to submit in this run: 19, 20 and 21.
partitions = list(range(19, 22))
partitions

In [None]:
#!rm -rf /tmp/ray

In [None]:
#!rm -rf /home/USER/tmp/22_medhistory/job_outputs

In [None]:
import time

# Submit one job per partition and collect whatever submit() returns for each.
submission_dir = f"/home/{USER}/tmp/{EXPERIMENT_NAME}/job_submissions"

jobids = []
for partition in partitions:
    job_script = make_job_script(user=USER, job_name=JOBNAME, partition=partition)
    jobid = submit(
        path=submission_dir,
        job_name=f"{JOBNAME}_{partition}",
        job_script=job_script,
    )
    jobids.append(jobid)
    # time.sleep(2)  # optional pause between submissions (disabled)

print(jobids)

## Check progress

In [None]:
# List the model files written so far (IPython shell capture: one entry per file name).
cox_paths = !ls "/sc-projects/sc-proj-ukb-cvd/results/projects/22_medical_records/data/230425/coxph/models/"

# p[:-2] drops the last two characters of each name
# (presumably the ".p" extension used for the pickled models - confirm).
path_df = pd.DataFrame(data = [p[:-2] for p in cox_paths]).rename(columns={0:"path"})
# Names are split on "_" into exactly four parts; this assumes the pattern
# <endpoint_1>_<endpoint_2>_<score>_<partition> with no further underscores.
path_df[["endpoint_1", "endpoint_2", "score", "partition"]] = path_df.path.str.split("_", expand=True,)
path_df["endpoint"] = path_df["endpoint_1"] + "_" + path_df["endpoint_2"]
path_df["partition"] = path_df["partition"].astype(int)
# Number of model files per partition, ordered by partition.
path_df.value_counts(["partition"]).to_frame().sort_index()

In [None]:
# Show the five most recently modified entries in the predictions directory.
!ls -t "/sc-projects/sc-proj-ukb-cvd/results/projects/22_medical_records/data/230425/coxph/predictions" | head -n5

## Fix Crashing CoxPH models

In [None]:
# Choose the storage root from the host name: hosts whose name contains "sc"
# use the /sc-projects mount, every other host uses the shared analysis mount.
node = !hostname
if "sc" in node[0]:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else: 
    base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

# Project-level locations derived from the storage root.
project_label = "22_medical_records"
project_path = f"{base_path}/results/projects/{project_label}"
figure_path = f"{project_path}/figures"
output_path = f"{project_path}/data"

# Experiment whose data is inspected below; its directory is created if missing.
# NOTE(review): this differs from the 230425 directory used in the progress cells - confirm intended.
experiment = 220613
experiment_path = f"{output_path}/{experiment}"
pathlib.Path(experiment_path).mkdir(parents=True, exist_ok=True)

In [None]:
# Read the endpoint metadata table and keep the endpoint names in sorted order.
endpoints_md = pd.read_csv(f"{experiment_path}/endpoints.csv")
endpoints = endpoints_md["endpoint"].tolist()
endpoints.sort()

In [None]:
# NOTE(review): exact repeat of the preceding cell (re-reads endpoints.csv and
# rebuilds the sorted endpoint list); one of the two copies can likely be removed.
endpoints_md = pd.read_csv(f"{experiment_path}/endpoints.csv")
endpoints = sorted(endpoints_md.endpoint.to_list())

In [None]:
import yaml
import pickle
import zstandard

def get_score_defs():
    """Load the score definitions from the project's YAML file.

    Returns:
        The object ``yaml.full_load`` parses from ``score_definitions.yaml``.
    """
    definitions_file = r'/home/USER/code/22_medical_records/1_processing/score_definitions.yaml'
    with open(definitions_file) as handle:
        return yaml.full_load(handle)

def get_features(endpoint, score_defs):
    """Build the covariate lists of every feature set for one endpoint.

    Args:
        endpoint: Endpoint name; used as the single "MedicalHistory" covariate.
        score_defs: Mapping providing the covariate lists under the keys
            "AgeSex", "SCORE2", "ASCVD" and "QRISK3".

    Returns:
        ``{"Identity(Records)+MLP": {feature_set_name: covariates}}``. The
        score-based entries are the very list objects stored in
        ``score_defs``; the two entries involving the endpoint are new lists.
    """
    age_sex = score_defs["AgeSex"]

    feature_sets = {"Age+Sex": age_sex}
    for score_name in ("SCORE2", "ASCVD", "QRISK3"):
        feature_sets[score_name] = score_defs[score_name]
    feature_sets["MedicalHistory"] = [endpoint]
    feature_sets["Age+Sex+MedicalHistory"] = age_sex + [endpoint]

    return {'Identity(Records)+MLP': feature_sets}

def load_pickle(fp):
    """Read a zstandard-compressed pickle file and return the stored object.

    Unpickling can execute arbitrary code, so only use this on trusted files.

    Args:
        fp: Path of the compressed pickle file.

    Returns:
        The unpickled object.
    """
    decompressor = zstandard.ZstdDecompressor()
    with open(fp, "rb") as raw, decompressor.stream_reader(raw) as stream:
        return pickle.loads(stream.read())

def fit_cox(data_fit, feature_set, covariates, endpoint, penalizer, step_size=1):
    """Fit a lifelines ``CoxPHFitter`` for one endpoint and return it.

    Args:
        data_fit: Data handed to the fitter unchanged; the duration and event
            columns are looked up as ``{endpoint}_time`` and ``{endpoint}_event``.
        feature_set: Not used by this function.
        covariates: Not used by this function.
        endpoint: Endpoint name from which the two column names are derived.
        penalizer: Forwarded to the ``CoxPHFitter`` constructor.
        step_size: Forwarded to ``fit`` as the ``step_size`` keyword.

    Returns:
        The fitted ``CoxPHFitter`` instance.
    """
    # NOTE(review): confirm the installed lifelines version still accepts
    # `step_size` as a direct keyword of fit().
    time_col = f"{endpoint}_time"
    event_col = f"{endpoint}_event"
    model = CoxPHFitter(penalizer=penalizer)
    model.fit(data_fit, time_col, event_col, step_size=step_size)
    return model

# Load the score definitions once; later cells read this module-level variable.
score_defs = get_score_defs()

In [None]:
# Covariates to drop from the model when the keyed endpoint is the outcome.
_ENDPOINT_EXCLUDED_COVARIATES = {
    "phecode_181": ("systemic_lupus_erythematosus",),  # Autoimmune disease
    "phecode_202": ("diabetes1", "diabetes2", "diabetes"),  # Diabetes
    "phecode_202-1": ("diabetes1",),  # Diabetes 1
    "phecode_202-2": ("diabetes2",),  # Diabetes 2
    "phecode_286": ("bipolar_disorder", "major_depressive_disorder"),  # Mood [affective] disorders
    "phecode_286-1": ("bipolar_disorder",),  # Bipolar disorder
    "phecode_286-2": ("major_depressive_disorder",),  # Major depressive disorder
    "phecode_287": ("schizophrenia",),  # Psychotic disorders
    "phecode_287-1": ("schizophrenia",),  # Schizophrenia
    "phecode_331": ("migraine",),  # Headache
    "phecode_331-6": ("migraine",),  # Headache
    "phecode_416": ("atrial_fibrillation",),  # Atrial fibrillation
    "phecode_416-2": ("atrial_fibrillation",),  # Atrial fibrillation and flutter
    "phecode_416-21": ("atrial_fibrillation",),  # Atrial fibrillation
    "phecode_584": ("renal_failure",),  # Renal failure
    "phecode_605": ("sex_Male", "male_erectile_dysfunction"),  # Male sexual dysfunction
    "phecode_605-1": ("sex_Male", "male_erectile_dysfunction"),  # Male sexual dysfunction
    "phecode_700": ("systemic_lupus_erythematosus",),  # Diffuse diseases of connective tissue
    "phecode_700-1": ("systemic_lupus_erythematosus",),  # Lupus
    "phecode_700-11": ("systemic_lupus_erythematosus",),  # Systemic lupus erythematosus [SLE]
    "phecode_705": ("rheumatoid_arthritis",),  # Rheumatoid arthritis and other inflammatory
    "phecode_705-1": ("rheumatoid_arthritis",),  # Rheumatoid arthritis and other inflammatory
}


def clean_covariates(endpoint, covariates):
    """Remove the covariates that must not be used for a given endpoint.

    Args:
        endpoint: Endpoint name, e.g. ``"phecode_202"``.
        covariates: Covariate names to filter.

    Returns:
        A new list without the excluded covariates when the endpoint has an
        exclusion entry; otherwise ``covariates`` itself, unchanged.
    """
    excluded = _ENDPOINT_EXCLUDED_COVARIATES.get(endpoint)
    if excluded is None:
        return covariates
    return [name for name in covariates if name not in excluded]

In [None]:
# Endpoints flagged as problematic; the labels repeat the ones used in
# clean_covariates.
problem_endpoints = [
    "phecode_181",  # autoimmune disease
    "phecode_202",  # diabetes
    "phecode_202-1",  # diabetes 1
    "phecode_286",  # mood [affective] disorders
    "phecode_287-1",  # schizophrenia
    "phecode_331",  # headache
    "phecode_416",  # atrial fibrillation
    "phecode_416-2",  # atrial fibrillation and flutter
    "phecode_416-21",  # atrial fibrillation
    "phecode_584", # renal failure
    "phecode_605",  # male sexual dysfunction
    "phecode_700",  # diffuse diseases of connective tissue
    "phecode_700-1",  # lupus
    "phecode_700-11",  # systemic lupus erythematosus [SLE]
    "phecode_705",  # rheumatoid arthritis and other inflammatory
    "phecode_705-1"  # rheumatoid arthritis and other inflammatory
]

# Show the metadata rows of the flagged endpoints.
endpoints_md.query("endpoint==@problem_endpoints")

In [None]:
from lifelines import CoxPHFitter
# Refit one model by hand from a saved "errordata" dump.
# The three values below select which dump file is loaded.
endpoint = "phecode_202-2"
partition = 11
feature_set = "QRISK3"
display(endpoints_md.query("endpoint==@endpoint"))

# The dump file name is built from endpoint, feature set and partition.
test_data = load_pickle(f"/sc-projects/sc-proj-ukb-cvd/results/projects/22_medical_records/data/220613/coxph/errordata_{endpoint}_{feature_set}_{partition}.p")
display(test_data.T)

# Covariate list of the chosen feature set.
features = get_features(endpoint, score_defs)
covariates = features["Identity(Records)+MLP"][feature_set]

# # clean covariates for the coxphs to fit
covariates = clean_covariates(endpoint, covariates)
print(covariates)

# Keep only the covariates plus the endpoint's event/time columns, cast to float32.
data_endpoint = test_data[covariates + [f"{endpoint}_event", f"{endpoint}_time"]].astype(np.float32)

# Unpenalized fit with a reduced step size (0.1 instead of fit_cox's default of 1).
cph = fit_cox(data_endpoint,#.drop(columns=["systemic_lupus_erythematosus"]), 
              feature_set, covariates, endpoint, penalizer=0, step_size=0.1)
cph.print_summary()

In [None]:
# Display the values collected by the submission loop.
jobids

In [None]:
#fit_partition(in_path, model_path, score_defs, 0)

In [None]:
import submitit

# Submit all 22 partitions (0-21) as one job array through submitit.
partitions = list(range(22))

executor = submitit.AutoExecutor(folder="log_test/%j")
# timeout_min is given in minutes.
executor.update_parameters(
    slurm_array_parallelism=6,
    nodes=1,
    timeout_min=600,
    # Options left disabled:
    #   slurm_mem="500G",
    #   slurm_setup=[
    #       """export MKL_NUM_THREADS=1""",
    #       """export NUMEXPR_NUM_THREADS=1""",
    #       """export OMP_NUM_THREADS=1""",
    #   ],
)

# NOTE(review): fit_partition is not defined in the visible part of this notebook.
job = executor.map_array(fit_partition, partitions)


# Alternative kept for reference: one submission per partition.
# jobs = []
# for partition in tqdm(partitions):
#     job = executor.submit(fit_partition, in_path, model_path, score_defs, partition)
#     jobs.append(job)

In [None]:
# Set the MKL, NumExpr and OpenMP thread-count variables to 1 for this kernel
# (presumably to avoid oversubscribing cores when work is parallelised with ray - confirm).
%env MKL_NUM_THREADS=1
%env NUMEXPR_NUM_THREADS=1
%env OMP_NUM_THREADS=1

import ray
#ray start --head --port=6379 --num-cpus 64 # in terminal
#ray.init(address='auto')#, dashboard_port=24762, dashboard_host="0.0.0.0", include_dashboard=True)#, webui_url="0.0.0.0"))
#ray.init(num_cpus=32)#, dashboard_port=24762, dashboard_host="0.0.0.0", include_dashboard=True)#, webui_url="0.0.0.0"))

In [None]:
# Show the resources ray currently reports as available.
ray.available_resources()

In [None]:
# Fit the partitions one after another in this process, with a progress bar.
# The loop variable is now handed to fit_partition; before, it was never used,
# so every iteration made the same argument-less call. This mirrors
# executor.map_array(fit_partition, partitions), which calls the function with
# one partition per task.
# NOTE(review): fit_partition is not defined in the visible part of this
# notebook - confirm it accepts the partition as its only argument.
for partition in tqdm(partitions):
    fit_partition(partition)

In [None]:
## Debugging

In [None]:
# fit_endpoint(data_partition, eids_dict, endpoint_defs, endpoints[0], partition, models, model_path)

In [None]:
# Load two fitted models of endpoint OMOP_4306655 (partition 0) for comparison:
# the QRISK3 model and the Age+Sex+MedicalHistory model.
models_dir = "/sc-projects/sc-proj-ukb-cvd/results/projects/22_medical_records/data/220613/coxph/models"
cph_1 = load_pickle(f"{models_dir}/OMOP_4306655_QRISK3_0.p")
cph_2 = load_pickle(f"{models_dir}/OMOP_4306655_Age+Sex+MedicalHistory_0.p")

In [None]:
# Print the summary of the QRISK3 model.
cph_1.print_summary()

In [None]:
# Print the summary of the Age+Sex+MedicalHistory model.
cph_2.print_summary()

In [None]:
# One-hot encode the categorical columns of the record table.
# The original cell could not be parsed: `encode_cols =` had no value and the
# call ended in a dangling `prefix=`. The column list is now bound to
# encode_cols and `prefix` is left at its default, which prefixes each dummy
# column with its source column name.
# NOTE(review): data_partition is not defined in the visible part of this notebook.
encode_cols = ["ethnic_background", "sex", "smoking_status"]
pd.get_dummies(data_partition['Identity(Records)+MLP'], columns=encode_cols)

In [None]:
# Show only the columns with categorical dtype.
# NOTE(review): data_partition is not defined in the visible part of this notebook.
data_partition['Identity(Records)+MLP'].select_dtypes("category")

In [None]:
# Inspect a saved "errordata" dump from the 220413 experiment directory.
load_pickle("/sc-projects/sc-proj-ukb-cvd/results/projects/22_medical_records/data/220413/coxph/errordata_phecode_002-1_Age+Sex+MedicalHistory+I(Age*MH)_0.p")

In [None]:
# Look at the phecode_977 column of the record table.
# NOTE(review): data_partition is not defined in the visible part of this notebook.
data_partition['Identity(Records)+MLP']['phecode_977']