In [None]:
import pathlib
import socket

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import hydra
from omegaconf import DictConfig, OmegaConf

import torch
import ray

In [None]:
# Pick the shared storage root for the cluster this kernel runs on.
# socket.gethostname() returns the same name as the `hostname` command but
# without shelling out through IPython's `!` syntax, so this cell is plain Python.
hostname = socket.gethostname()
if "sc" in hostname:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else:
    base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

# Project-level output locations (created if missing).
project_label = "22_medical_records"
project_path = f"{base_path}/results/projects/{project_label}"
figure_path = f"{project_path}/figures"
output_path = f"{project_path}/data"

pathlib.Path(figure_path).mkdir(parents=True, exist_ok=True)
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

# Outputs of this notebook go to {output_path}/{experiment}.
experiment = 230425
experiment_path = f"{output_path}/{experiment}"
pathlib.Path(experiment_path).mkdir(parents=True, exist_ok=True)

In [None]:
# Tear down any Ray runtime left over from an earlier execution of this
# notebook so the next cell can start a fresh one with the desired resources.
if ray.is_initialized():
    ray.shutdown()

In [None]:
# Local Ray runtime used for the parallel prediction export; dashboard disabled.
NUM_RAY_CPUS = 24
ray.init(include_dashboard=False, num_cpus=NUM_RAY_CPUS)

In [None]:
import wandb

api = wandb.Api(timeout=60)
# W&B entity and project holding the runs -- set these to your own.
entity = "cardiors"
project = "recordgraphs"
tag = "revision_230425"
# Only fetch runs that carry `tag`.
run_filters = {"tags": {"$in": [tag]}}
runs = api.runs(f"{entity}/{project}", filters=run_filters)

In [None]:
# One metadata record per tagged run. `path` is None when the run's config has
# no "best_checkpoint" entry.
run_list = []
for run in tqdm(runs):
    record = {
        "id": run.path[-1],
        "name": run.name,
        "tags": run.tags,
        # drop config keys that start with an underscore
        "config": {key: value for key, value in run.config.items() if not key.startswith('_')},
        "summary": run.summary._json_dict,
    }
    cfg = run.config
    if "best_checkpoint" in cfg.keys():
        # Two directory levels above the checkpoint file; the export loop later
        # reads predictions/predictions.feather relative to this directory.
        record["path"] = str(pathlib.Path(cfg["best_checkpoint"]).parent.parent)
    else:
        record["path"] = None
    run_list.append(record)

In [None]:
# Keep only runs for which a checkpoint directory could be derived.
runs_df = pd.DataFrame(run_list).dropna(subset=["path"])
print(runs_df.shape)
runs_df.sort_values("name")

In [None]:
# Quick look at the directory of the first run (in frame order).
runs_df["path"].iloc[0]

In [None]:
# Peek at one run's raw predictions. The run is looked up by id in `runs_df`
# rather than hard-coding an absolute path. The file location then follows
# the run's recorded checkpoint directory, and a KeyError flags an id that is
# not among the tagged runs.
SAMPLE_RUN_ID = "w0vw9p4h"
sample_run_path = runs_df.set_index("id").loc[SAMPLE_RUN_ID, "path"]
temp = pd.read_feather(f"{sample_run_path}/predictions/predictions.feather")

In [None]:
# Show only the first rows of the sample predictions instead of the whole frame.
temp.head()

## Process Predictions

In [None]:
# Raw model label -> short display name. A raw label is
# "<module>_<str(covariate_cols)>_<encoder>_<head>", i.e. the string the
# prediction export concatenates from those four columns.
name_dict = dict([
    ("CovariatesOnlyTraining_['age_at_recruitment_f21022_0_0', 'sex_f31_0_0']_None_MLPHead", "Identity(AgeSex)+MLP"),
    ("RecordsIdentityTraining_[]_None_MLPHead", "Identity(Records)+MLP"),
    ("RecordsIdentityTraining_[]_None_LinearHead", "Identity(Records)+Linear"),
    ("RecordsGraphTraining_[]_HeteroGNN_MLPHead", "GNN(Records)+MLP"),
    ("RecordsIdentityTraining_['age_at_recruitment_f21022_0_0', 'sex_f31_0_0']_None_MLPHead", "Identity(AgeSex+Records)+MLP"),
    ("RecordsGraphTraining_['age_at_recruitment_f21022_0_0', 'sex_f31_0_0']_HeteroGNN_MLPHead", "GNN(AgeSex+Records)+MLP"),
])

In [None]:
# Identifier columns of the exported predictions (cast to category before writing).
id_vars = "eid model partition split".split()

In [None]:
# Root directory for the per-model / per-partition split files.
out_path = "/".join([experiment_path, "loghs"])
pathlib.Path(out_path).mkdir(exist_ok=True, parents=True)

In [None]:
@ray.remote
def prepare_predictions(in_path, out_path):
    """Split one run's predictions file into per-split feather files.

    Reads ``in_path``, renames ``index`` to ``eid`` and builds a readable model
    label from the ``module``/``covariate_cols``/``encoder``/``head`` columns.
    The label is mapped through the notebook-level ``name_dict``, and unmapped
    labels are kept verbatim. The ``id_vars`` columns are cast to ``category``.
    One file is written per split to
    ``{out_path}/{model}/{partition}/{split}.feather``.

    Args:
        in_path: path to a ``predictions.feather`` file produced by a run.
        out_path: root directory for the per-split outputs.

    Returns:
        The directory the split files were written to.
    """
    temp = (
        pd.read_feather(in_path)
        .rename(columns={"index": "eid"})
        .reset_index(drop=True)
    )
    # Build and map the label while it is still a plain string column, then
    # cast to category. Calling replace() on a categorical column is deprecated
    # in pandas 2.x.
    raw_model = (
        temp["module"].astype(str)
        + "_" + temp["covariate_cols"].astype(str)
        + "_" + temp["encoder"].astype(str)
        + "_" + temp["head"].astype(str)
    )
    temp["model"] = raw_model.replace(name_dict).astype("category")
    temp = temp.drop(columns=["module", "encoder", "head", "covariate_cols", "record_cols"])
    for c in id_vars:
        temp[c] = temp[c].astype("category")

    # NOTE(review): only the first model/partition value is used, so this
    # assumes each predictions file holds a single model and partition.
    model = temp["model"].unique()[0]
    partition = temp["partition"].unique()[0]

    # The output directory does not depend on the split, so create it once.
    fp_out = f"{out_path}/{model}/{partition}"
    pathlib.Path(fp_out).mkdir(parents=True, exist_ok=True)
    for split in ["train", "valid", "test"]:
        temp.query("split==@split").reset_index(drop=True).to_feather(f"{fp_out}/{split}.feather")
    return fp_out

In [None]:
# Submit one Ray task per run, then block until every task has finished.
# Without ray.get the cell returns immediately, exceptions raised inside tasks
# never surface here, and the progress bar only measures submission.
pending = [
    prepare_predictions.remote(f"{p}/predictions/predictions.feather", out_path)
    for p in runs_df.path
]
with tqdm(total=len(pending), desc="prepare_predictions") as pbar:
    while pending:
        done, pending = ray.wait(pending, num_returns=1)
        ray.get(done)  # re-raises the task's exception, if any
        pbar.update(len(done))

In [None]:
# Sanity check (replaces a leftover `1+1` scratch cell): count the split files
# written as {out_path}/{model}/{partition}/{split}.feather.
written_files = sorted(pathlib.Path(out_path).glob("*/*/*.feather"))
len(written_files)