# Benchmarks

## Initialize

In [None]:
import os
import math
import pathlib
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from IPython.display import clear_output

import warnings
from lifelines.utils import CensoringType
from lifelines.utils import concordance_index

In [None]:
# Root of all input/output data for this notebook (machine-specific path).
base_path = "/home/jakobs"

project_path = f"{base_path}/data"

# Experiment identifier; everything produced for this run lives under experiment_path.
experiment = '230321'
experiment_path = f"{project_path}/{experiment}"
pathlib.Path(experiment_path).mkdir(parents=True, exist_ok=True)

# Data partition indices 0..9 (presumably cross-validation folds — confirm).
partitions = list(range(10))

# Date stamp embedded in file names written by earlier steps
# (e.g. eligible_eids_{today}.feather).
today = '230321'

In [None]:
# Directory holding the fitted Cox PH models (read later via load_pickle).
model_path = f"{experiment_path}/coxph/models"
# IPython shell capture: model_list holds the output lines of `ls` on model_path.
# NOTE(review): model_list is not used in the visible cells, and model_path is
# reassigned to the same value further down — consider removing this cell.
model_list =  !ls $model_path

In [None]:
# Endpoint definitions keyed by endpoint name; the names drive the prediction loop.
endpoint_defs = (
    pd.read_feather(f"{base_path}/data/endpoints_epic_md.feather")
    .set_index("endpoint")
)
endpoints = endpoint_defs.index.tolist()

In [None]:
# Data splits in use. The per-eye splits ("test_left", "test_right") are disabled.
splits = ["train", "test"]

In [None]:
# One row per endpoint with the ids eligible for it; build an endpoint -> eid_list lookup.
eligable_eids = pd.read_feather(f"{experiment_path}/eligible_eids_{today}.feather")
eids_dict = dict(zip(eligable_eids["endpoint"], eligable_eids["eid_list"]))

In [None]:
# Limit MKL, numexpr and OpenMP thread pools to 4 threads each.
# NOTE(review): numpy is imported in the first cell, before these are set; native
# BLAS/OpenMP libraries typically read these variables when they load, so this may
# not affect the running kernel — confirm, or move these above the imports.
# NOTE(review): with ray.init(address='auto') the tasks run on an externally started
# cluster whose workers may not inherit this kernel's environment — confirm.
%env MKL_NUM_THREADS=4
%env NUMEXPR_NUM_THREADS=4
%env OMP_NUM_THREADS=4

In [None]:
import ray
# Drop any Ray connection left over from a previous run of this cell.
ray.shutdown()
# Attach to an existing Ray cluster, e.g. one started beforehand with:
#ray start --head --port=6379 --num-cpus 64
ray.init(address='auto')
# Alternative: start a private local Ray instance instead of attaching to a cluster.
#ray.init(num_cpus=24)#, dashboard_port=24762, dashboard_host="0.0.0.0", include_dashboard=True)#, webui_url="0.0.0.0"))

# Predict with Cox PH models

In [None]:
# Cox PH step layout: per-model input data, fitted models, and prediction outputs.
in_path = pathlib.Path(experiment_path, "coxph", "input")
model_path = f"{experiment_path}/coxph/models"

out_path = f"{experiment_path}/coxph/predictions"
pathlib.Path(out_path).mkdir(exist_ok=True, parents=True)

In [None]:
# Display the Cox PH input directory as cell output.
in_path

In [None]:
# Model names to predict for; each has its own sub-folder under in_path.
models = ["RetinaUKB"]
models

In [None]:
#AgeSex = ["age_at_recruitment_f21022_0_0", "sex_f31_0_0"]

In [None]:
from scripts.coxph_fit_partition import load_data, get_score_defs, get_features

In [None]:
from lifelines import CoxPHFitter
from lifelines.exceptions import ConvergenceError
import zstandard
import pickle
import os

def get_test_data(in_path, partition, models):
    """Load the test split of one partition for every model.

    Reads ``{in_path}/{model}/{partition}/test.feather`` for each name in
    ``models`` and indexes it by the ``eid`` column.

    Returns:
        dict mapping model name -> DataFrame indexed by ``eid``.
    """
    test_frames = {}
    for model_name in models:
        frame = pd.read_feather(f"{in_path}/{model_name}/{partition}/test.feather")
        test_frames[model_name] = frame.set_index("eid")
    return test_frames
            
def load_pickle(fp):
    """Load a zstandard-compressed pickle file and return the unpickled object.

    Security: unpickling can execute arbitrary code, so only call this on
    trusted files (such as the fitted models this notebook writes/reads).
    """
    decompressor_ctx = zstandard.ZstdDecompressor()
    with open(fp, "rb") as raw, decompressor_ctx.stream_reader(raw) as reader:
        payload = reader.read()
    return pickle.loads(payload)

def predict_cox(cph, data_endpoint, endpoint, feature_set, partition, pred_path, model):
    """Write cumulative-risk predictions ``F(t) = 1 - S(t)`` for years 1..15.

    Args:
        cph: fitted survival model exposing
            ``predict_survival_function(X, times=...)`` (a lifelines
            ``CoxPHFitter`` where this notebook calls it).
        data_endpoint: covariates, one row per individual, indexed by ``eid``.
            The caller's frame is not modified.
        endpoint, feature_set, partition: stored on every row as the
            ``endpoint``, ``features`` and ``partition`` columns.
        pred_path: destination feather file.
        model: model name. Not used here because the caller already encodes
            it in ``pred_path``; kept for call compatibility.

    The written frame has columns ``eid``, ``endpoint``, ``features``,
    ``partition`` and ``Ft_1`` ... ``Ft_15``.
    """
    times = list(range(1, 16))

    if feature_set == "Age+Sex+MedicalHistory+I(Age*MH)":
        # Strip "-" from column names for this feature set (presumably to match
        # the names the model was fitted with — confirm). Rename a copy so the
        # caller's frame keeps its columns.
        data_endpoint = data_endpoint.rename(columns=lambda c: c.replace("-", ""))

    # predict_survival_function gives times x individuals; transpose once so each
    # horizon is a column whose rows follow data_endpoint's row order.
    risk = (1 - cph.predict_survival_function(data_endpoint, times=times)).T

    preds = data_endpoint.reset_index()[["eid"]].assign(
        endpoint=endpoint, features=feature_set, partition=partition
    )
    for t in times:
        preds[f"Ft_{t}"] = risk[t].to_list()

    # Use the caller-provided path rather than rebuilding it from a global.
    preds.to_feather(pred_path)

@ray.remote
def predict_endpoint(data_partition, eids_dict, endpoint, partition, models, features, model_path, out_path):
    """Ray task: write Cox predictions for one endpoint in one partition.

    For each model and each feature set in ``features[model]``:

    - loads ``{model_path}/{endpoint}_{feature_set}_{model}_{partition}.p``;
    - writes predictions for the endpoint's eligible ``eid``s to
      ``{out_path}/<same identifier>.feather``.

    Feature sets whose prediction file already exists are skipped, so the
    task can be re-run after an interruption. Missing model files are
    reported and skipped.

    Returns:
        True once every feature set has been handled.
    """
    eids_incl = eids_dict[endpoint].tolist()
    for model in models:
        data_model = data_partition[model]
        # The eligibility mask depends only on the model's rows; compute it once.
        in_cohort = data_model.index.isin(eids_incl)
        for feature_set in features[model]:
            identifier = f"{endpoint}_{feature_set}_{model}_{partition}"
            pred_path = f"{out_path}/{identifier}.feather"
            if os.path.isfile(pred_path):
                continue  # already predicted in an earlier run
            try:
                cph = load_pickle(f"{model_path}/{identifier}.p")
            except FileNotFoundError:
                print(f"{identifier} not available")
                continue
            # Boolean indexing yields a fresh frame per feature set.
            predict_cox(cph, data_model[in_cohort], endpoint, feature_set, partition, pred_path, model)
    return True

In [None]:
import yaml

score_defs = get_score_defs()

# Shared, read-only lookup of eligible eids per endpoint for all tasks.
ray_eids = ray.put(eids_dict)

# Feature sets depend only on the endpoint, so resolve them once rather than per partition.
features_by_endpoint = {
    endpoint: get_features(endpoint, score_defs, models) for endpoint in endpoints
}

ray_partition = None
for partition in tqdm(partitions):
    # Drop the previous partition's reference before putting the next one, so
    # only one partition's test data is pinned in the object store at a time.
    ray_partition = None
    ray_partition = ray.put(get_test_data(in_path, partition, models))
    progress = [
        predict_endpoint.remote(
            ray_partition, ray_eids, endpoint, partition, models,
            features_by_endpoint[endpoint], model_path, out_path,
        )
        for endpoint in endpoints
    ]
    # Wait for every task of this partition; ray.get re-raises task failures.
    for task_ref in tqdm(progress):
        ray.get(task_ref)