In [None]:
import pathlib
import socket

import hydra
import numpy as np
import pandas as pd
import ray
import torch
from omegaconf import DictConfig, OmegaConf
from torch_geometric import seed_everything
from tqdm.auto import tqdm

In [None]:
# Pick the storage root from the machine name: hosts whose name contains "sc"
# use the /sc-projects mount, every other host uses the shared analysis mount.
# The hostname is read with the standard library instead of a `!hostname`
# shell escape, so the cell is plain Python and does not depend on a shell.
# `node` stays a one-element list so existing `node[0]` lookups keep working.
node = [socket.gethostname()]
if "sc" in node[0]:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else:
    base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

# Project-level directories derived from the storage root.
project_label = "22_medical_records"
project_path = f"{base_path}/results/projects/{project_label}"
figure_path = f"{project_path}/figures"
output_path = f"{project_path}/data"

# Experiment identifier; it names the sub-directory of `output_path` that holds
# this experiment's files. The directory is created if it does not exist yet.
experiment = 220627
experiment_path = f"{output_path}/{experiment}"
pathlib.Path(experiment_path).mkdir(parents=True, exist_ok=True)

In [None]:
# Load the attribution table of this experiment and collect the distinct
# values of its `record` column as a plain Python list.
attribution_df = pd.read_feather(
    f"{experiment_path}/attributions_pre.feather"
)
records = attribution_df["record"].unique().tolist()

In [None]:
# Endpoint definitions, ordered by endpoint name.
endpoint_defs = (
    pd.read_feather(f"{output_path}/phecode_defs_220306.feather")
    .sort_values("endpoint")
)

# Endpoints selected for this experiment, as a sorted list of names.
endpoints_md = pd.read_csv(f"{experiment_path}/endpoints.csv")
endpoints = sorted(endpoints_md["endpoint"].to_list())

In [None]:
# Outcome table, indexed by eid.
data_outcomes = (
    pd.read_feather(f"{output_path}/baseline_outcomes_220627.feather")
    .set_index("eid")
)

# Record table restricted to the columns named in `records`, indexed by eid.
data_records = (
    pd.read_feather(
        f"{output_path}/baseline_records_220627.feather",
        columns=["eid"] + records,
    )
    .set_index("eid")
)

# Left-join the outcomes onto the records by eid: every row of `data_records`
# is kept, outcome columns are attached where the eid matches.
data_all = data_records.merge(
    data_outcomes,
    how="left",
    left_index=True,
    right_index=True,
)

In [None]:
# Table with one row per endpoint and its `eid_list`; turned into a
# lookup endpoint -> eid_list.
eligable_eids = pd.read_feather(
    f"{output_path}/eligable_eids_220627.feather"
)
eids_by_endpoint = eligable_eids.set_index("endpoint")["eid_list"]
eids_dict = eids_by_endpoint.to_dict()

In [None]:
# Shut down any Ray runtime still attached to this kernel before the
# `ray.init` call in the following cell.
ray.shutdown()

In [None]:
# `ray` is also imported in the first cell; repeating the import is a no-op.
import ray

# Initialise Ray with 20 CPUs and the dashboard disabled.
ray.init(num_cpus=20, include_dashboard=False)

In [None]:
@ray.remote
def calc_ratio(data_all, eids_dict, record, eids_record, endpoints):
    """Event count and event rate per endpoint among individuals having `record`.

    For every endpoint, the eligible eids of that endpoint are restricted to
    those that also appear in `eids_record`; the `{endpoint}_event` column of
    `data_all` is then summed (event count) and averaged (event rate) over
    that subset.

    Args:
        data_all: Frame indexed by eid that has one `{endpoint}_event` column
            for each entry of `endpoints`.
        eids_dict: Mapping endpoint -> array of eligible eids. Each array is
            indexed with a boolean mask, so it must be a NumPy array; its
            entries are assumed to be unique (`assume_unique=True`).
        record: Label of the record; only copied into the output rows.
        eids_record: Array of eids for which the record is present; assumed
            to contain unique values.
        endpoints: Iterable of endpoint names to evaluate.

    Returns:
        A list with one dict per endpoint, with the keys `endpoint`,
        `n_eligable`, `record`, `n_records`, `n_events_record` and
        `freq_events_record`. `freq_events_record` is NaN when no eligible
        individual has the record.
    """
    rows = []

    for endpoint in endpoints:
        eids_endpoint = eids_dict[endpoint]

        # Eligible eids of this endpoint that also carry the record.
        has_record = np.isin(eids_endpoint, eids_record, assume_unique=True)
        eids_temp = eids_endpoint[has_record]

        events = data_all[f"{endpoint}_event"].loc[eids_temp]
        n_record = events.sum()
        # An empty subset has no defined rate; report NaN instead of dividing
        # by zero.
        freq_record = n_record / len(events) if len(events) > 0 else np.nan

        rows.append({"endpoint": endpoint, "n_eligable": len(eids_endpoint),
                     "record": record, "n_records": len(eids_record),
                     "n_events_record": n_record, "freq_events_record": freq_record})
    return rows

In [None]:
# Column-wise mean of the record table, ordered from the largest value down.
record_means = data_records.mean()
record_freqs = record_means.sort_values(ascending=False)
record_freqs

In [None]:
# Put the two shared inputs into Ray's object store once, so each task is
# handed a reference to them.
ref_data_all = ray.put(data_all)
ref_eids_dict = ray.put(eids_dict)

# Submit one task per record, in the order of `record_freqs`.
d_nested = []
for record in tqdm(record_freqs.index):
    s_record = data_all[record]
    # eids of the rows in which this record column equals True.
    eids_record = s_record.loc[s_record.eq(True)].index.values
    ref_results = calc_ratio.remote(
        ref_data_all, ref_eids_dict, record, eids_record, endpoints
    )
    d_nested.append(ref_results)

# Collect the results in submission order; each entry becomes the list of row
# dicts returned for one record.
d_nested = [ray.get(ref) for ref in tqdm(d_nested)]

# Drop the local references to the shared objects.
del ref_data_all, ref_eids_dict

In [None]:
from itertools import chain

# Flatten the per-record lists of row dicts into one flat list of rows.
d = list(chain.from_iterable(d_nested))

In [None]:
# One row per dict in `d`; the dict keys become the columns.
endpoints_freqs = pd.DataFrame(d)

In [None]:
# Show the assembled frame as the cell output for a visual check.
endpoints_freqs

In [None]:
# Write `endpoints_freqs` to a feather file in the experiment directory.
endpoints_freqs.to_feather(f"{experiment_path}/attributions_conditional_eventrates.feather")