# Benchmarks

## Initialize

In [None]:
import os
import math
import pathlib
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from IPython.display import clear_output

import warnings
from lifelines.utils import CensoringType
from lifelines.utils import concordance_index

In [None]:
node = !hostname
if "sc" in node[0]:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else: 
    base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

project_label = "22_medical_records"
project_path = f"{base_path}/results/projects/{project_label}"
figure_path = f"{project_path}/figures"
output_path = f"{project_path}/data"

experiment = 230425
experiment_path = f"{output_path}/{experiment}"
pathlib.Path(experiment_path).mkdir(parents=True, exist_ok=True)

In [None]:
# Show the five most recently modified entries in the CoxPH model directory.
# NOTE(review): the path is hard-coded to the sc-projects mount and experiment
# 230425 instead of being derived from `experiment_path`.
!ls -t "/sc-projects/sc-proj-ukb-cvd/results/projects/22_medical_records/data/230425/coxph/models" | head -n5

In [None]:
# Endpoint metadata of this experiment; the endpoint names are kept sorted.
endpoints_md = pd.read_csv(f"{experiment_path}/endpoints.csv")
endpoints = sorted(endpoints_md["endpoint"].tolist())

In [None]:
# Partition indices 0..21 and the names of the data splits.
partitions = list(range(22))
splits = ["train", "valid", "test"]

In [None]:
# Phecode definitions restricted to the selected endpoints, sorted by and
# indexed on the endpoint name.
endpoint_defs = (
    pd.read_feather(f"{output_path}/phecode_defs_220306.feather")
    .query("endpoint==@endpoints")
    .sort_values("endpoint")
    .set_index("endpoint")
)

In [None]:
# Eligible participant ids per endpoint, as a dict: endpoint -> `eid_list` entry.
eligable_eids = pd.read_feather(f"{output_path}/eligable_eids_220627.feather")
eids_dict = eligable_eids.set_index("endpoint").loc[:, "eid_list"].to_dict()

In [None]:
# Set the MKL, NumExpr and OpenMP thread-count environment variables to 2.
# NOTE(review): numpy is imported in an earlier cell and the Ray cluster is
# started outside this notebook; confirm these settings reach the processes
# they are meant for.
%env MKL_NUM_THREADS=2
%env NUMEXPR_NUM_THREADS=2
%env OMP_NUM_THREADS=2

In [None]:
# Disconnect from any Ray session left over from a previous run.
# `ray` is imported here because the notebook's own import sits in the next
# cell; without it this cell raises NameError on a fresh kernel.
import ray

ray.shutdown()

In [None]:
import ray

# Attach to an already running Ray cluster. Start one beforehand in a terminal:
#   ray start --head --port=6379 --num-cpus 64
# Alternative (local cluster started from the notebook): ray.init(num_cpus=32)
# Optional dashboard arguments previously tried here:
#   dashboard_port=24762, dashboard_host="0.0.0.0", include_dashboard=True
ray.init(address="auto")

# Predict COX

In [None]:
# Locations of the CoxPH inputs, the fitted models and the prediction outputs.
in_path = pathlib.Path(f"{output_path}/{experiment}/coxph/input")
model_path = f"{experiment_path}/coxph/models"
out_path = f"{experiment_path}/coxph/predictions"

# The prediction directory has to exist before anything is written into it.
pathlib.Path(out_path).mkdir(exist_ok=True, parents=True)

In [None]:
# NOTE(review): repeats the `model_path` assignment from the previous cell
# with the same value.
model_path = f"{experiment_path}/coxph/models"

In [None]:
# Display the resolved model directory as the cell output.
model_path

In [None]:
# Model variants whose test data and fitted CoxPH models are used below.
models = [
    "Identity(Records)+MLP",
    "Identity(Records)+Linear",
]

In [None]:
from lifelines import CoxPHFitter
from lifelines.exceptions import ConvergenceError
import zstandard
import pickle
import os

def get_score_defs(path=r'/home/USER/code/MedicalHistoryPhenomeWide/2_downstream_processing/score_definitions.yaml'):
    """Load the score definitions from a YAML file.

    Args:
        path: Location of the YAML file. Defaults to the project's
            ``score_definitions.yaml`` so existing zero-argument calls keep
            working.

    Returns:
        The parsed YAML document. ``get_features`` indexes it with the keys
        "AgeSex", "Comorbidities", "SCORE2", "ASCVD" and "QRISK3".
    """
    # Imported locally because the notebook only imports yaml in a later cell;
    # this keeps the function usable regardless of cell execution order.
    import yaml

    with open(path) as file:
        # NOTE(review): full_load is kept as in the original; the file is
        # project-owned. Consider safe_load if it only holds plain data.
        score_defs = yaml.full_load(file)

    return score_defs

def get_features(endpoint, score_defs):
    """Build the named covariate sets for one endpoint, grouped by model.

    Args:
        endpoint: Endpoint name; used as the column that carries the
            medical-history predictor.
        score_defs: Mapping with the keys "AgeSex", "Comorbidities",
            "SCORE2", "ASCVD" and "QRISK3", each holding a list of columns.

    Returns:
        ``{model name: {feature-set name: list of covariate columns}}``.
    """
    age_sex = score_defs["AgeSex"]
    comorbidities = score_defs["Comorbidities"]
    score2 = score_defs["SCORE2"]
    ascvd = score_defs["ASCVD"]
    qrisk3 = score_defs["QRISK3"]

    # Feature sets evaluated with the MLP-derived medical-history predictor.
    mlp_sets = {
        "MedicalHistory": [endpoint],
        "Age+Sex": age_sex,
        "Comorbidities": comorbidities,
        "SCORE2": score2,
        "ASCVD": ascvd,
        "QRISK3": qrisk3,
        "Age+Sex+Comorbidities": age_sex + comorbidities,
        "Age+Sex+MedicalHistory": age_sex + [endpoint],
        "SCORE2+MedicalHistory": score2 + [endpoint],
        "ASCVD+MedicalHistory": ascvd + [endpoint],
        "QRISK3+MedicalHistory": qrisk3 + [endpoint],
        "Age+Sex+Comorbidities+MedicalHistory": age_sex + comorbidities + [endpoint],
    }

    # Feature sets evaluated with the linear-model medical-history predictor.
    linear_sets = {
        "MedicalHistoryLM": [endpoint],
        "Age+Sex+MedicalHistoryLM": age_sex + [endpoint],
        "Age+Sex+Comorbidities+MedicalHistoryLM": age_sex + comorbidities + [endpoint],
    }

    return {
        "Identity(Records)+MLP": mlp_sets,
        "Identity(Records)+Linear": linear_sets,
    }

def get_test_data(in_path, partition, models):
    """Read the test split of one partition for every model.

    Args:
        in_path: Root directory holding ``<model>/<partition>/test.feather``.
        partition: Partition identifier used as the sub-directory name.
        models: Iterable of model names.

    Returns:
        Dict mapping each model name to its test frame, indexed by ``eid``.
    """
    frames = {}
    for name in models:
        frame = pd.read_feather(f"{in_path}/{name}/{partition}/test.feather")
        frames[name] = frame.set_index("eid")
    return frames
            
def load_pickle(fp):
    """Return the object stored in a zstandard-compressed pickle file.

    Args:
        fp: Path of the compressed pickle.

    Returns:
        The unpickled object.
    """
    # NOTE: unpickling executes arbitrary code; only use on trusted files.
    with open(fp, "rb") as raw, zstandard.ZstdDecompressor().stream_reader(raw) as reader:
        payload = reader.read()
        return pickle.loads(payload)

def predict_cox(cph, data_endpoint, endpoint, feature_set, partition, pred_path):
    """Predict yearly event probabilities and write them to ``pred_path``.

    For every horizon ``t`` in 1..15 the column ``Ft_<t>`` holds
    ``1 - S(t)``, where ``S`` is what ``cph.predict_survival_function``
    returns for the rows of ``data_endpoint``.

    Args:
        cph: Fitted model exposing ``predict_survival_function(df, times=...)``.
        data_endpoint: Covariate frame; after ``reset_index()`` it must
            provide an ``eid`` column.
        endpoint: Endpoint name stored in the ``endpoint`` column.
        feature_set: Feature-set name stored in the ``features`` column.
        partition: Partition identifier stored in the ``partition`` column.
        pred_path: Feather file the predictions are written to.
    """
    times = list(range(1, 16))

    if feature_set == "Age+Sex+MedicalHistory+I(Age*MH)":
        # Strip "-" from the column names for this feature set. Done on a
        # renamed copy so the caller's frame is left untouched.
        data_endpoint = data_endpoint.rename(columns=lambda c: c.replace("-", ""))

    # 1 - survival; transposed once so each horizon can be read as a column.
    event_prob = (1 - cph.predict_survival_function(data_endpoint, times=times)).T

    temp_pred = data_endpoint.reset_index()[["eid"]].assign(
        endpoint=endpoint, features=feature_set, partition=partition
    )
    for t in times:
        temp_pred[f"Ft_{t}"] = event_prob[t].to_list()

    # Write to the path handed in by the caller instead of rebuilding it from
    # a global `out_path`, which is not guaranteed to exist or to match.
    temp_pred.to_feather(pred_path)

@ray.remote
def predict_endpoint(data_partition, eids_dict, endpoint, partition, models, features, model_path, out_path):
    """Write CoxPH predictions for one endpoint and partition (Ray task).

    For every model and feature set the fitted model is loaded from
    ``{model_path}/{endpoint}_{feature_set}_{partition}.p`` and applied to the
    eligible rows of that model's test data. Combinations whose prediction
    file already exists are skipped. Failures are reported with ``print`` and
    do not abort the remaining combinations.

    Args:
        data_partition: Dict mapping model name to its test frame, indexed by eid.
        eids_dict: Dict mapping endpoint to an array of eligible eids
            (``.tolist()`` is called on the entry).
        endpoint: Endpoint to predict.
        partition: Partition identifier.
        models: Model names to iterate over.
        features: ``{model: {feature_set: covariates}}``; only the feature-set
            names are used here.
        model_path: Directory containing the fitted models.
        out_path: Directory the prediction files are written to.

    Returns:
        Always ``True``.
    """
    eids_incl = eids_dict[endpoint].tolist()
    for model in models:
        data_model = data_partition[model]
        # The eligibility mask depends only on the model's frame, so it is
        # built at most once per model, and only if a prediction is needed.
        eligible = None
        for feature_set in features[model]:
            identifier = f"{endpoint}_{feature_set}_{partition}"
            pred_path = f"{out_path}/{identifier}.feather"
            if os.path.isfile(pred_path):
                continue
            try:
                cph = load_pickle(f"{model_path}/{identifier}.p")
                if eligible is None:
                    eligible = data_model.index.isin(eids_incl)
                predict_cox(cph, data_model[eligible], endpoint, feature_set, partition, pred_path)
            except FileNotFoundError:
                print(f"{identifier} not available")
            except Exception as exc:
                # Stay best-effort, but do not disguise real errors as a
                # missing model file.
                print(f"{identifier} failed: {exc!r}")
    return True

In [None]:
import yaml
score_defs = get_score_defs()

# Put the eligibility lookup into the Ray object store once for all tasks.
ray_eids = ray.put(eids_dict)
for partition in tqdm(partitions):
    # Drop the reference to the previous partition's data before loading the
    # next one (presumably so Ray can release it from the object store).
    try:
        del ray_partition
    except NameError:
        # First iteration: nothing has been stored yet.
        print("Ray object not yet initialised")
    ray_partition = ray.put(get_test_data(in_path, partition, models))

    # One remote task per endpoint for the current partition.
    progress = []
    for endpoint in endpoints:
        features = get_features(endpoint, score_defs)
        progress.append(predict_endpoint.remote(ray_partition, ray_eids, endpoint, partition, models, features, model_path, out_path))

    # Block until every task of this partition has finished.
    for task in tqdm(progress):
        ray.get(task)

In [None]:
# Scratch cell; the value is not assigned to anything.
1+1