In [None]:
import pandas as pd
import numpy as np
import pathlib
from tqdm.auto import tqdm

import hydra
from omegaconf import DictConfig, OmegaConf

import torch
from torch_geometric import seed_everything

import ray

In [None]:
node = !hostname
if "sc" in node[0]:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else: 
    base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

project_label = "22_medical_records"
project_path = f"{base_path}/results/projects/{project_label}"
figure_path = f"{project_path}/figures"
output_path = f"{project_path}/data"

pathlib.Path(figure_path).mkdir(parents=True, exist_ok=True)
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

## Get Data

In [None]:
from hydra import compose, initialize
from omegaconf import OmegaConf

# Drop any Hydra state left over from an earlier run of this cell before
# initialising again.
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path="../../ehrgraphs/config")

# Compose the project config with the notebook-specific overrides.
args = compose(
    config_name="config",
    overrides=[
        "datamodule.partition=0",
        "datamodule.use_top_n_phecodes=10000",
        "setup.use_data_artifact_if_available=False",
        "datamodule/covariates='no_covariates'",
        "datamodule.t0_mode=recruitment",
        "+use_phecode_inputs=False",
    ],
)
print(OmegaConf.to_yaml(args))

In [None]:
from ehrgraphs.training import setup_training

# Fix the random seed before the datamodule is built.
seed_everything(0)

# setup_training returns three values; only the first one (the datamodule) is
# used in this notebook, the other two are discarded.
datamodule, _, _ = setup_training(args)

In [None]:
def extract_records_events_times(datamodule, splits=("train", "valid", "test")):
    """Collect the records and outcomes of the requested splits into two frames.

    For every split the dataset's sparse ``records``, ``exclusions`` and
    ``labels_events`` matrices are wrapped into sparse DataFrames indexed by
    ``datamodule.eids[split]``. Event times are densified and every entry equal
    to 0 is replaced by the censoring time of the same row.

    NOTE(review): a time of exactly 0 is treated as "no event"; an event that
    really happened at time 0 would be overwritten by the censoring time --
    confirm that this cannot occur in the data.

    Parameters
    ----------
    datamodule
        Object providing ``record_cols_input``, ``label_mapping``, ``eids`` and
        the ``train_dataloader`` / ``val_dataloader`` / ``test_dataloader``
        methods whose ``.dataset`` carries ``records``, ``exclusions``,
        ``labels_events``, ``labels_times`` and ``censorings``.
    splits : iterable of str, optional
        Splits to extract, in output row order. Each entry must be one of
        ``"train"``, ``"valid"`` or ``"test"``. Defaults to all three.

    Returns
    -------
    records_df : pandas.DataFrame
        One row per eid, one sparse column per entry of ``record_cols_input``.
    outcomes_df : pandas.DataFrame
        One row per eid with ``<label>_prev``, ``<label>_event`` (sparse) and
        ``<label>_time`` (dense) columns for every key of ``label_mapping``.

    Raises
    ------
    ValueError
        If ``splits`` contains a name other than the three supported ones.
    """
    # Explicit dispatch: an unsupported split name fails loudly instead of
    # silently re-using the dataset of the previous iteration.
    dataset_getters = {
        "train": lambda: datamodule.train_dataloader(shuffle=False, drop_last=False).dataset,
        "valid": lambda: datamodule.val_dataloader().dataset,
        "test": lambda: datamodule.test_dataloader().dataset,
    }
    splits = list(splits)
    unknown = [s for s in splits if s not in dataset_getters]
    if unknown:
        raise ValueError(
            f"Unknown split(s) {unknown}; expected a subset of {list(dataset_getters)}"
        )

    # Column names are identical for every split, so build them once.
    record_cols = [f"{c}" for c in datamodule.record_cols_input]
    label_cols = list(datamodule.label_mapping.keys())
    prev_cols = [f"{c}_prev" for c in label_cols]
    event_cols = [f"{c}_event" for c in label_cols]
    time_cols = [f"{c}_time" for c in label_cols]

    records_list = []
    outcomes_list = []

    for s in tqdm(splits):
        eids = datamodule.eids[s]
        dataset = dataset_getters[s]()

        # extract records
        records_temp = pd.DataFrame.sparse.from_spmatrix(
            dataset.records, index=eids, columns=record_cols
        ).rename_axis("eid")
        records_list.append(records_temp)

        # extract exclusion & events
        exclusions_df = pd.DataFrame.sparse.from_spmatrix(
            dataset.exclusions, index=eids, columns=prev_cols
        ).rename_axis("eid")
        events_df = pd.DataFrame.sparse.from_spmatrix(
            dataset.labels_events, index=eids, columns=event_cols
        ).rename_axis("eid")

        # Work on a plain ndarray rather than the np.matrix that todense() may return.
        times = np.asarray(dataset.labels_times.todense())
        censorings = np.asarray(dataset.censorings)

        # Fill entries without a recorded time with the row's censoring time.
        # broadcast_to gives a read-only view, so no (rows x labels) copy of the
        # censorings is materialised.
        no_event_idxs = times == 0
        times[no_event_idxs] = np.broadcast_to(censorings[:, None], times.shape)[no_event_idxs]

        times_df = pd.DataFrame(data=times, index=eids, columns=time_cols).rename_axis("eid")

        outcomes_temp = pd.concat([exclusions_df, events_df, times_df], axis=1)
        outcomes_list.append(outcomes_temp)

    records_df = pd.concat(records_list, axis=0)
    outcomes_df = pd.concat(outcomes_list, axis=0)

    return records_df, outcomes_df

In [None]:
# Build the combined records and outcomes frames from the datamodule.
records_df, outcomes_df = extract_records_events_times(datamodule)

## Write Records

In [None]:
# Inspect column dtypes and memory usage before the sparse columns are densified.
records_df.info()

In [None]:
# Convert every record column to a plain (dense) boolean column.
for c in tqdm(records_df.columns):
    column = records_df[c].astype(bool)
    # Only sparse columns expose the `.sparse` accessor; checking the dtype keeps
    # this cell re-runnable once the columns have already been densified.
    if isinstance(column.dtype, pd.SparseDtype):
        column = column.sparse.to_dense()
    records_df[c] = column

In [None]:
# Order the rows by their index (eid).
records_df = records_df.sort_index()

In [None]:
# Re-inspect dtypes and memory usage after densifying and sorting.
records_df.info()

In [None]:
# to_feather does not accept a custom index, so eid is moved into a regular
# column before the wide records table is written.
records_df.reset_index().to_feather(f"{output_path}/baseline_records_220627.feather")

### Records long

In [None]:
# Long format: one row per (eid, record) pair for which the record is set.
# Columns whose name contains "phecode" are left out.
record_ids = sorted(r for r in records_df.columns.unique().tolist() if "phecode" not in r)

records_df_list = []
for r in tqdm(record_ids):
    # Select the rows where the record is set with a boolean mask. `query(r)`
    # would parse the column name as an expression, which breaks for names
    # that are not valid Python identifiers.
    is_set = records_df[r].to_numpy(dtype=bool)
    temp = records_df.loc[is_set, [r]].assign(record=r)
    # After renaming, "record" holds the (always True) flag and "concept" holds
    # the record id; this naming is kept so the written file keeps its schema.
    temp.columns = ["record", "concept"]
    records_df_list.append(temp)

records_long = (
    pd.concat(records_df_list, axis=0)[["concept", "record"]]
    .assign(concept=lambda x: x.concept.astype("category"))
    .reset_index()
)

In [None]:
# Inspect dtypes and memory usage of the long-format records table.
records_long.info()

In [None]:
# Write the long-format records table (eid is already a regular column here).
records_long.to_feather(f"{output_path}/baseline_records_long_220627.feather")

## Write Outcomes

In [None]:
# Normalise the outcome dtypes: flags become dense booleans, times become float32.
for c in tqdm(outcomes_df.columns):
    if c.endswith(("_prev", "_event")):
        column = outcomes_df[c].astype(bool)
        # Only sparse columns expose the `.sparse` accessor; checking the dtype
        # keeps this cell re-runnable once the columns are already dense.
        if isinstance(column.dtype, pd.SparseDtype):
            column = column.sparse.to_dense()
        outcomes_df[c] = column
    elif c.endswith("_time"):
        outcomes_df[c] = outcomes_df[c].astype(np.float32)

In [None]:
# Order the rows by their index (eid).
outcomes_df = outcomes_df.sort_index()

In [None]:
# Inspect dtypes and memory usage after the dtype conversion and sorting.
outcomes_df.info()

In [None]:
# to_feather does not accept a custom index, so eid is moved into a regular
# column before the wide outcomes table is written.
outcomes_df.reset_index().to_feather(f"{output_path}/baseline_outcomes_220627.feather")

### Outcomes long

In [None]:
# Recover the endpoint ids by stripping the trailing _prev/_event/_time suffix.
# The pattern is anchored to the end of the name so that an endpoint id which
# itself contains one of these fragments is not mangled.
endpoints = sorted(
    outcomes_df.columns.str.replace(r"(?:_prev|_event|_time)$", "", regex=True).unique().tolist()
)

In [None]:
# Placeholder only: outcomes_long is reassigned from the concatenated
# per-endpoint frames in a later cell.
outcomes_long = pd.DataFrame()

In [None]:
# Build one frame per endpoint with the generic columns prev / event / time
# plus an "endpoint" column carrying the endpoint id.
cols = ["prev", "event", "time"]
outcomes_df_list = []
for e in tqdm(endpoints):
    source_cols = [f"{e}_{c}" for c in cols]
    temp = (
        outcomes_df[source_cols]
        .rename(columns=dict(zip(source_cols, cols)))
        .assign(endpoint=e)
    )
    outcomes_df_list.append(temp)

In [None]:
# Stack the per-endpoint frames, put "endpoint" first, store it as a
# categorical, and turn the eid index into a regular column.
outcomes_long = (
    pd.concat(outcomes_df_list, axis=0)
    .loc[:, ["endpoint"] + cols]
    .astype({"endpoint": "category"})
    .reset_index()
)

In [None]:
# Inspect dtypes and memory usage of the long-format outcomes table.
outcomes_long.info()

In [None]:
# Write the long-format outcomes table (eid is already a regular column here).
outcomes_long.to_feather(f"{output_path}/baseline_outcomes_long_220627.feather")