# Benchmarks

## Initialize

In [None]:
%load_ext autoreload
%autoreload 2

import os
from tqdm.auto import tqdm
import pathlib

import numpy as np
import pandas as pd
import lifelines

In [None]:
# Pin MKL, numexpr and OpenMP to a single thread each, presumably so that the
# 30 Ray worker CPUs requested below do not oversubscribe the machine.
# NOTE(review): numpy is already imported in the first cell; thread-count
# variables are usually read when the math libraries load, so these likely only
# affect processes started afterwards (e.g. Ray workers) — confirm.
%env MKL_NUM_THREADS=1
%env NUMEXPR_NUM_THREADS=1
%env OMP_NUM_THREADS=1

In [None]:
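# ray.shutdown()  # uncomment to stop the running Ray instance before calling ray.init again with different settings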

In [None]:
import ray

# Start a local Ray runtime: 30 CPUs, dashboard on port 24763.
# ignore_reinit_error=True makes re-executing this cell a no-op instead of
# raising because Ray is already initialised.
# NOTE(review): dashboard_host="0.0.0.0" exposes the Ray dashboard on every
# network interface; prefer "127.0.0.1" unless remote access is really needed.
ray.init(
    num_cpus=30,
    dashboard_port=24763,
    dashboard_host="0.0.0.0",
    include_dashboard=True,
    ignore_reinit_error=True,
)

In [None]:
import socket

# Storage root depends on the host: machines whose hostname contains "sc" use
# the /sc-projects share, all others use the cardioRS share. Plain Python
# (socket.gethostname) instead of the IPython-only `!hostname` escape.
# NOTE(review): the substring test matches any hostname containing "sc" —
# confirm this is specific enough.
hostname = socket.gethostname()
if "sc" in hostname:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else:
    base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

# Input dataset locations for this project (pre- and post-processing stages).
project_name = "210714_metabolomics"
data_pre = f"{base_path}/data/2_datasets_pre/{project_name}"
data_post = f"{base_path}/data/3_datasets_post/{project_name}"

# Output locations for figures and result tables; created if missing.
project_label = "21_metabolomics_multitask"
project_path = f"{base_path}/results/projects/{project_label}"
figures_path = f"{project_path}/figures"
data_results_path = f"{project_path}/data"
for directory in (figures_path, data_results_path):
    pathlib.Path(directory).mkdir(parents=True, exist_ok=True)

In [None]:
# exclude heart failure, venous thrombosis, aortic aneurysm partition21

In [None]:
# Identifier of the prediction run to load (part of the predictions file name).
# NOTE: `run` is reassigned to a different value before the results are saved.
run = "220126"

In [None]:
# Full merged participant-level dataset (all columns).
data =  pd.read_feather(f"{data_post}/data_merged.feather")

In [None]:
# Model predictions for this run; columns used below include eid, endpoint,
# module, features, partition and Ft_10.
preds_models = pd.read_feather(f"{data_results_path}/predictions_{run}_metabolomics.feather")

In [None]:
# Distinct endpoints and partitions present in the loaded predictions.
# Both names are recomputed further down (from `preds` and `data_test`).
endpoints = preds_models.endpoint.unique().tolist()
partitions = preds_models.partition.unique().tolist()

In [None]:
# Eligible participants per endpoint: NMR-measured and `{endpoint}==False`,
# with extra statin/sex restrictions for selected endpoints.
# Reuses `data` (loaded from the same feather file above) instead of reading
# the file a second time.
# NOTE(review): `eids_dict` is not used later in this notebook — confirm it is
# still needed.
extra_filters = {
    "M_MACE": "&statins==False",
    "M_breast_cancer": "&sex=='Female'",
    "M_ovarian_cancer": "&sex=='Female'",
    "M_uterus_cancer": "&sex=='Female'",
    "M_prostate_cancer": "&sex=='Male'",
}
eids_dict = {}
for endpoint in tqdm(endpoints):
    condition = f"NMR_FLAG==True&{endpoint}==False{extra_filters.get(endpoint, '')}"
    eids_incl = data.query(condition).eid.to_list()
    print(endpoint, len(eids_incl))
    eids_dict[endpoint] = eids_incl

In [None]:
# Working frame of the predictions with a fresh 0..n-1 index.
preds = preds_models.reset_index(drop=True)

In [None]:
endpoints = preds.endpoint.unique().tolist()
# Outcome columns "<endpoint>_event" and "<endpoint>_event_time" for every
# endpoint, in sorted order; only these (plus eid) are loaded.
endpoint_labels = sorted(
    f"{e}{suffix}" for e in endpoints for suffix in ("_event", "_event_time")
)
endpoint_data = pd.read_feather(
    f"{data_post}/data_merged.feather", columns=["eid", *endpoint_labels]
)

In [None]:
# Notebook display: partitions present in the predictions.
preds.partition.unique()

In [None]:
# Columns needed for benchmarking: identifiers, model description and Ft_10.
# NOTE(review): despite its name, `data_test` is not restricted to a test split
# (no `split` filter is applied) — confirm this is intended.
data_test = preds[["eid", "endpoint", "module", "features", "partition", "Ft_10"]]

# Distinct model modules, feature sets and partitions to iterate over.
modules, features, partitions = (
    data_test[column].unique().tolist() for column in ("module", "features", "partition")
)

In [None]:
# Bootstrap iteration indices; one resample is drawn per index.
iterations = list(range(1000))

In [None]:
# Participants with NMR measurements. `.copy()` makes this an independent frame,
# so adding the "age" column afterwards does not write into a query result
# (avoids pandas' SettingWithCopyWarning / chained-assignment ambiguity).
data_nmr = data.query("NMR_FLAG==True").copy()

In [None]:
# Notebook display: participant counts per ethnic background among NMR participants.
data_nmr.value_counts("ethnic_background")

In [None]:
def age_bins(age, lower=50, upper=60):
    """Map an age to one of three bin labels.

    Args:
        age: Numeric age (NaN allowed).
        lower: Lower bin edge; ages below it map to "<lower".
        upper: Upper bin edge; ages in [lower, upper] map to "lower-upper",
            ages above it map to ">upper".

    Returns:
        The bin label ("<50", "50-60" or ">60" with the default edges), or
        None for NaN, where every comparison is False.
    """
    if age < lower:
        return f"<{lower}"
    if age <= upper:
        return f"{lower}-{upper}"
    if age > upper:
        return f">{upper}"
    return None  # NaN: none of the comparisons hold
# Age-bin label per participant, used below as a stratification variable.
data_nmr["age"] = data_nmr["age_at_recruitment"].map(age_bins)

In [None]:
# {stratification column: {subgroup value: [eid, ...]}} for age bin, sex and
# ethnic background among NMR participants.
eids_dict_sg = {}
for grouping in ["age", "sex", "ethnic_background"]:
    eids_dict_sg[grouping] = data_nmr.groupby(grouping)["eid"].apply(list).to_dict()

In [None]:
# Report the number of participants in every subgroup.
# NOTE: the loop variables (`group`, `subgroup`, `eids`) remain defined at
# notebook scope after the loop.
for group, subgroups in eids_dict_sg.items():
    print(group)
    for subgroup, eids in subgroups.items():
        print(f"{subgroup} {len(eids)}")

In [None]:
from IPython.display import clear_output
#from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc, integrated_brier_score
from lifelines.utils import concordance_index
from dask.diagnostics import ProgressBar

def calculate_per_endpoint(df, endpoint, module, feature, group, subgroup, len_sg, iteration, time):
    """Compute the C-index of the ``Ft_{time}`` score for one endpoint on one sample.

    Outcomes are truncated at the horizon ``time``. Participants without an
    event, or whose event happens after ``time``, are treated as censored at
    ``time``. Rows with a missing event time, score or event indicator are
    dropped before scoring.

    Args:
        df: Sample with columns ``Ft_{time}``, ``{endpoint}_event`` and
            ``{endpoint}_event_time``.
        endpoint: Endpoint name; selects the outcome columns.
        module, feature, group, subgroup, len_sg, iteration: Identifiers copied
            verbatim into the returned record.
        time: Evaluation horizon; also selects the ``Ft_{time}`` score column.

    Returns:
        Dict with the identifiers, ``time`` and ``cindex``. ``cindex`` is NaN
        when it cannot be computed: no rows left, no event before the horizon,
        or lifelines rejects the data.
    """
    event_obs = df[f"{endpoint}_event"]
    event_time_obs = df[f"{endpoint}_event_time"]
    # Censor at the horizon. Comparisons with NaN are False, as with scalars.
    censored = (event_obs == 0) | (event_time_obs > time)
    df = df.assign(
        event=np.where(censored, 0, 1),
        event_time=np.where(censored, time, event_time_obs),
    )
    df = df.dropna(subset=["event_time", f"Ft_{time}", "event"], axis=0)

    cindex = np.nan
    # Without any observed event there are no admissible pairs; skip lifelines.
    if not df.empty and df["event"].any():
        try:
            # lifelines treats higher scores as longer survival. Ft_* is presumably
            # a risk score (higher = earlier event), hence the 1 - ... inversion.
            cindex = 1 - concordance_index(df["event_time"], df[f"Ft_{time}"], df["event"])
        except (ZeroDivisionError, ValueError):
            # ZeroDivisionError: no admissible pairs; ValueError: invalid inputs.
            cindex = np.nan
    return {"endpoint": endpoint, "module": module, "features": feature, "group": group,
            "subgroup": subgroup, "len_sg": len_sg, "iteration": iteration, "time": time,
            "cindex": cindex}

@ray.remote
def calc_per_iteration(data_bm, eids_bs, endpoint, modules, features, group, subgroup, len_sg, iteration):
    """Ray task: score every (module, features) model on one bootstrap resample.

    For each module/feature combination that has rows in ``data_bm``, rows are
    selected by the resampled ids ``eids_bs`` (repeated ids give repeated rows).
    The C-index at horizon 10 is then computed with ``calculate_per_endpoint``.
    NOTE(review): ``.loc[eids_bs]`` raises KeyError if a sampled eid has no row
    for some module/features combination — confirm every model covers all eids.

    Returns:
        List of result dicts, one per scored module/features combination.
    """
    columns = ["eid", "Ft_10", f"{endpoint}_event", f"{endpoint}_event_time"]
    results = []
    print(group, subgroup)
    for module in tqdm(modules, desc=f"{endpoint} ({iteration})"):
        by_module = data_bm[data_bm["module"] == module]
        for feature in features:
            subset = by_module[by_module["features"] == feature]
            if subset.empty:
                continue
            resampled = subset[columns].set_index("eid").loc[eids_bs].reset_index()
            results.append(calculate_per_endpoint(
                resampled, endpoint, module, feature, group, subgroup, len_sg, iteration, time=10))
    return results

In [None]:
# Bootstrap driver. For every endpoint x subgroup, draw len(iterations)
# resamples (with replacement) of the participants that have predictions, and
# submit one Ray task per resample. Task handles are collected in `rows`.
rng = np.random.default_rng(0)  # fixed seed so the resamples are reproducible
rows = []
for endpoint in tqdm(endpoints):
    data_bm = data_test.set_index("eid").query("endpoint==@endpoint").merge(
        endpoint_data[["eid", f"{endpoint}_event", f"{endpoint}_event_time"]].set_index("eid"),
        left_index=True, right_index=True, how="left").reset_index()
    for group, subgroups in eids_dict_sg.items():
        for subgroup, subgroup_eids in subgroups.items():
            data_sg = data_bm[data_bm["eid"].isin(subgroup_eids)]
            # Only participants with predictions for this endpoint can be resampled.
            eids_sg = data_sg.eid.unique()
            if len(eids_sg) == 0:
                print(endpoint, group, subgroup)  # nothing to resample
                continue
            data_sg_id = ray.put(data_sg)
            for iteration in iterations:
                # Each resample has the same size as the resampled population.
                eids_bs = rng.choice(eids_sg, size=len(eids_sg), replace=True)
                rows.append(calc_per_iteration.remote(
                    data_sg_id, eids_bs, endpoint, modules, features,
                    group, subgroup, len(eids_sg), iteration))

In [None]:
# Wait for all Ray tasks and flatten their per-task lists of result dicts.
rows_finished = []
for task_rows in ray.get(rows):
    rows_finished.extend(task_rows)

In [None]:
# One row per result dict (endpoint, module, features, subgroup, iteration, cindex, ...).
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0; building
# the frame directly from the list of dicts yields the same table.
benchmark_endpoints_pp = pd.DataFrame(rows_finished)
clear_output()

In [None]:
# Output run id. This overwrites `run`, which held the input prediction run id.
run = "220128"
# Encode the actual number of bootstrap iterations in the file name so it
# cannot drift from `iterations`.
name = f"benchmark{len(iterations)}_cindex_subgroups_{run}"
benchmark_endpoints_pp.to_feather(f"{data_results_path}/{name}.feather")