# Benchmarks

## Initialize

In [None]:
#library(Rmisc)
library(tidyverse)
library(glue)
library(arrow)
library(patchwork)

In [None]:
# Storage root: hosts whose node name contains "sc" use the /sc-projects mount,
# every other host uses the shared ag-reils volume.
# NOTE(review): the plain substring test matches any hostname containing "sc" —
# confirm that is specific enough for the machines this runs on.
base_path <- if (grepl("sc", Sys.info()[["nodename"]], fixed = TRUE)) {
    "/sc-projects/sc-proj-ukb-cvd"
} else {
    "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
}
print(base_path)

# Input dataset locations.
# NOTE(review): `path` is not referenced by the cells visible here — confirm it is still needed.
dataset_name <- "210714_metabolomics"
path <- "/data/analysis/ag-reils/steinfej/code/umbrella/pre/ukbb"
data_path <- glue("{base_path}/data")
dataset_path <- glue("{data_path}/3_datasets_post/{dataset_name}")

# Result locations for this project.
project_label <- "21_metabolomics_multitask"
project_path <- glue("{base_path}/results/projects/{project_label}")
figures_path <- glue("{project_path}/figures")
data_results_path <- glue("{project_path}/data")

In [None]:
# Plot sizing constants used by the global theme below.
base_size <- 8
title_size <- 10
facet_size <- 10
geom_text_size <- 3

library(ggplot2)

# Global theme: classic look, unboxed facet strips, thin axis lines/ticks,
# black axis text and the legend placed below the panels.
theme_set(
    theme_classic(base_size = base_size) +
        theme(
            strip.background = element_blank(),
            plot.title = element_text(size = title_size, hjust = 0),
            strip.text.x = element_text(size = facet_size),
            axis.title = element_text(size = 10),
            axis.text = element_text(size = 8, color = "black"),
            legend.position = "bottom",
            axis.line = element_line(size = 0.2),
            axis.ticks = element_line(size = 0.2)
        )
)

In [None]:
library("jsonlite")
# Palette definitions read from colors.json in the working directory.
colors_dict <- jsonlite::read_json(path = "colors.json")

In [None]:
# Colour per model label; "all" and "none" get fixed greys, the rest come from
# the pastel palette in colors.json.
# NOTE(review): keys look like <model>_<feature set> (presumably COX = Cox
# baselines, DS = deep survival models) — confirm against the results tables.
color_map <- c(
    "all" = "grey",
    "none" = "black",
    "COX_Age+Sex" = colors_dict$pastel$red$mid,
    "COX_Metabolomics" = colors_dict$pastel$blue$light,
    "DS_Metabolomics" = colors_dict$pastel$blue$mid,
    "DS_Age+Sex+Metabolomics" = colors_dict$pastel$green$mid,
    "DS_AgeSexMetabolomics" = colors_dict$pastel$green$dark
)

In [None]:
library(data.table)

In [None]:
library(ggthemes)
# Display label for each endpoint column. Keys must match the dataset's column
# names verbatim (they are selected with all_of()), including the
# "obstructuve" spelling of the COPD column.
endpoint_map <- c(
    "M_MACE" = "MACE",
    "M_all_cause_dementia" = "Dementia",
    "M_type_2_diabetes" = "T2 Diabetes",
    "M_liver_disease" = "Liver Disease",
    "M_renal_disease" = "Renal Disease",
    "M_atrial_fibrillation" = "Atrial Fibrillation",
    "M_heart_failure" = "Heart Failure",
    "M_coronary_heart_disease" = "CHD",
    "M_venous_thrombosis" = "Ven. Thrombosis",
    "M_cerebral_stroke" = "Cerebral Stroke",
    "M_abdominal_aortic_aneurysm" = "AAA",
    "M_peripheral_arterial_disease" = "PAD",
    "M_chronic_obstructuve_pulmonary_disease" = "COPD",
    "M_asthma" = "Asthma",
    "M_parkinsons_disease" = "Parkinson's",
    "M_lung_cancer" = "Lung Cancer",
    "M_non_melanoma_skin_cancer" = "Skin Cancer",
    "M_colon_cancer" = "Colon Cancer",
    "M_rectal_cancer" = "Rectal Cancer",
    "M_prostate_cancer" = "Prostate Cancer",
    "M_breast_cancer" = "Breast Cancer",
    "M_cataracts" = "Cataracts",
    "M_glaucoma" = "Glaucoma",
    "M_fractures" = "Fractures"
)

# Facet / row order for the figures and the supplementary table (same 24
# endpoints as endpoint_map, with the cancer endpoints last).
endpoint_order <- c(
    "M_MACE", "M_coronary_heart_disease", "M_cerebral_stroke",
    "M_all_cause_dementia", "M_heart_failure", "M_atrial_fibrillation",
    "M_type_2_diabetes", "M_liver_disease", "M_renal_disease",
    "M_peripheral_arterial_disease", "M_venous_thrombosis", "M_abdominal_aortic_aneurysm",
    "M_chronic_obstructuve_pulmonary_disease", "M_asthma", "M_parkinsons_disease",
    "M_cataracts", "M_glaucoma", "M_fractures",
    "M_lung_cancer", "M_non_melanoma_skin_cancer", "M_colon_cancer",
    "M_rectal_cancer", "M_prostate_cancer", "M_breast_cancer"
)

In [None]:
# A 12-endpoint subset (not referenced by the cells shown here).
endpoint_selection <- c(
    "M_MACE", "M_coronary_heart_disease", "M_cerebral_stroke",
    "M_all_cause_dementia", "M_type_2_diabetes", "M_liver_disease",
    "M_renal_disease", "M_venous_thrombosis", "M_asthma",
    "M_chronic_obstructuve_pulmonary_disease", "M_parkinsons_disease", "M_cataracts"
)

## Load data

In [None]:
# Show every directory under the project results folder (full paths, recursive;
# both are list.dirs() defaults).
list.dirs(project_path)

In [None]:
# Model run whose predictions / log hazards are analysed.
run <- "211007"
DSM <- "MultiTaskSurvivalTraining"

# Merged participant data and its column description.
data <- arrow::read_feather(glue("{dataset_path}/data_merged.feather"))
data_description <- arrow::read_feather(glue("{dataset_path}/description_merged.feather"))

# Model outputs for this run.
predictions <- arrow::read_feather(glue("{data_results_path}/predictions_{run}_metabolomics.feather"))

# Log-hazard columns are named logh_<endpoint>_<features>. They are reshaped to long
# format with one row per original row and logh_* column. The greedy first group
# means `features` is whatever follows the LAST underscore.
loghazards <- arrow::read_feather(glue("{data_results_path}/loghazards_model_{run}_metabolomics.feather")) %>%
    pivot_longer(
        starts_with("logh"),
        names_to = c("endpoint", "features"),
        names_pattern = "logh_?(.*)_(.*)$",
        values_to = "logh"
    )

In [None]:
# Endpoint indicator columns (M_*), excluding their *_event / *_event_time companions.
data %>%
    select(starts_with("M_"), -contains("_event")) %>%
    colnames()

In [None]:
# Reshape the per-endpoint outcome columns (<endpoint>_event, <endpoint>_event_time)
# into one row per (eid, endpoint) with separate `event` and `event_time` columns.
data_events <- data %>%
    select(eid, ends_with("event"), ends_with("event_time")) %>%
    pivot_longer(
        -eid,
        names_to = c("endpoint", "type"),
        names_pattern = "(.*)(event_time|event)",
        values_to = "value"
    ) %>%
    # the captured endpoint still ends in the "_" separator; drop that character
    mutate(endpoint = stringr::str_sub(endpoint, end = -2)) %>%
    pivot_wider(names_from = "type", values_from = "value")

In [None]:
# Attach observed outcomes (event, event_time) to log hazards and predictions.
loghazards_tte <- left_join(loghazards, data_events, by = c("endpoint", "eid"))
predictions_tte <- left_join(predictions, data_events, by = c("endpoint", "eid"))

In [None]:
# Peek at the first rows of the joined prediction table.
predictions_tte %>% head()

In [None]:
# Column names of the joined log-hazard table.
colnames(loghazards_tte)

In [None]:
# Test-split log hazards plus the hazard ratio exp(logh); the module/datamodule
# bookkeeping columns are dropped.
logh_T_raw <- loghazards_tte %>%
    filter(split == "test") %>%
    mutate(hr = exp(logh)) %>%
    select(-c(module, datamodule))

In [None]:
# Prevalent-disease indicator per participant and endpoint, long format, coerced
# to integer (presumably 0/1). The next cell recomputes this identically.
prev_data <- data %>%
    select(eid, all_of(names(endpoint_map))) %>%
    pivot_longer(cols = -eid, names_to = "endpoint", values_to = "Prevalent") %>%
    mutate(Prevalent = as.integer(Prevalent))

In [None]:
# Prevalent-disease indicator per participant and endpoint, long format, coerced
# to integer (presumably 0/1).
prev_data <- data %>%
    select(eid, all_of(names(endpoint_map))) %>%
    pivot_longer(-eid, names_to = "endpoint", values_to = "Prevalent") %>%
    mutate(Prevalent = as.integer(Prevalent))

# Metabolomic-state percentile (1-100) of the log hazard, ranked within each
# endpoint x feature set:
#  - logh_T_inc: incident set = rows with Prevalent == 0 for that endpoint
#    (rows whose Prevalent is NA are dropped by filter() as well)
#  - logh_T_all: complete test split, prevalent cases included
# Both results are ungrouped so later select()/summarise() calls do not silently
# carry the grouping columns along (logh_T_inc used to stay grouped).
logh_T_inc <- logh_T_raw %>%
    left_join(prev_data, by = c("eid", "endpoint")) %>%
    filter(Prevalent == 0) %>%
    group_by(endpoint, features) %>%
    mutate(logh_perc = ntile(logh, 100)) %>%
    ungroup()

logh_T_all <- logh_T_raw %>%
    left_join(prev_data, by = c("eid", "endpoint")) %>%
    group_by(endpoint, features) %>%
    mutate(logh_perc = ntile(logh, 100)) %>%
    ungroup()

# Figure 2: Selected Endpoints

## Metabolic state and incident disease

In [None]:
# Observed event rate (`ratio` = mean of event) per metabolomic-state percentile,
# per endpoint and feature set. logh_perc is recomputed with the same grouping as
# in logh_T_inc. The result stays grouped by (endpoint, features).
logh_T_agg <- logh_T_inc %>%
    group_by(endpoint, features) %>%
    mutate(logh_perc = ntile(logh, 100)) %>%
    group_by(endpoint, features, logh_perc) %>%
    summarise(ratio = mean(event), .groups = "drop_last")

# Mean log hazard among participants with (`1`) and without (`0`) an event, kept
# separate per feature set (previously all feature sets were pooled per endpoint).
# delta = separation between the two means, largest first. The intermediate
# column was named `median_logh` but held a mean; it is now named accordingly.
labels <- logh_T_inc %>%
    group_by(endpoint, features, event) %>%
    summarise(mean_logh = mean(logh), .groups = "drop") %>%
    pivot_wider(names_from = "event", values_from = "mean_logh") %>%
    mutate(delta = `1` - `0`) %>%
    arrange(desc(delta))

In [None]:
options(repr.plot.width = 8, repr.plot.height = 5, repr.plot.res = 320)

# Feature set shown in the event-rate panel (top panel of Figure 2).
features <- "Metabolomics"

# `features` is also a column of logh_T_agg. The .data/.env pronouns make explicit
# that the column is compared with the value chosen above.
temp <- logh_T_agg %>%
    filter(.data$features == .env$features) %>%
    mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
    ungroup()

# Observed event rate (%) against metabolomic-state percentile, one facet per endpoint.
met_events <- ggplot(temp, aes(x = logh_perc, y = ratio * 100, color = logh_perc)) +
    labs(title = NULL, x = "Metabolomics State Percentile [%]", y = "Observed Event Rate [%]") +
    geom_point(alpha = 0.7, size = 0.1) +
    scale_colour_gradient(
        low = "#7AC6FF",
        high = "#023768",
        space = "Lab",
        na.value = "grey50",
        guide = "colourbar",
        aesthetics = "colour"
    ) +
    scale_y_continuous(limits = c(0, NA), expand = expansion(mult = c(0, .05))) +
    scale_x_continuous(expand = expansion(add = c(0, 1))) +
    # `scales` spelled out: `scale =` only worked through partial argument matching
    facet_wrap(~endpoint, scales = "free", labeller = labeller(endpoint = endpoint_map), ncol = 6) +
    theme(legend.position = "none")
met_events

### Bootstrapped confidence intervals for top- vs bottom-decile event rates

In [None]:
## bootstrapped CIs for percentiles
# Resample participants with replacement n_boot times. Each replicate keeps all
# endpoint log hazards of a sampled participant together (wide table), and column
# `i` records the replicate index.
# Only the Metabolomics feature set is bootstrapped; the decile table and KM plots
# use only that set. With several feature sets per endpoint, pivot_wider() would
# otherwise collapse duplicate (eid, endpoint) values into list-columns, and the
# event join below would fan out.
# NOTE(review): df_boot holds n_boot full copies of the cohort in long format;
# memory grows linearly with n_boot.
n_boot <- 1000
set.seed(211007)  # fixed seed so the bootstrapped CIs in the table are reproducible

df_times <- logh_T_inc %>%
    ungroup() %>%
    filter(features == "Metabolomics") %>%
    select(eid, endpoint, logh) %>%
    pivot_wider(names_from = endpoint, values_from = logh)

# Preallocated via lapply instead of growing a list inside a for loop.
datalist <- lapply(seq_len(n_boot), function(b) {
    df_times %>%
        sample_frac(replace = TRUE) %>%
        mutate(i = b)
})

# Back to long format; participants without a log hazard for an endpoint are
# dropped, then each row gets that participant's observed event.
df_boot <- dplyr::bind_rows(datalist) %>%
    pivot_longer(-c(eid, i), names_to = "endpoint", values_to = "logh") %>%
    filter(!is.na(logh)) %>%
    left_join(
        logh_T_inc %>%
            ungroup() %>%
            filter(features == "Metabolomics") %>%
            select(eid, endpoint, event),
        by = c("eid", "endpoint")
    )

In [None]:
# Decile (1-10) of the log hazard, ranked separately within every endpoint and
# bootstrap replicate; the result is ungrouped.
logh_T_all_bs <- df_boot %>%
    group_by(endpoint, i) %>%
    mutate(logh_perc = ntile(logh, n = 10)) %>%
    ungroup()

In [None]:
# Observed event rate per endpoint, replicate and decile. Only the last grouping
# level is dropped, so the result stays grouped by (endpoint, i).
temp_bs <- logh_T_all_bs %>%
    group_by(endpoint, i, logh_perc) %>%
    summarise(rate = mean(event), .groups = "drop_last")

In [None]:
# Format a number with exactly two decimals, e.g. 1.5 -> "1.50".
# sprintf() is used because format() pads every element of a vector to a common
# width, which put leading spaces into the table strings.
fmt_2dp <- function(x) sprintf("%.2f", round(x, 2))

# Summarise the bootstrap replicates per endpoint into display strings of the form
# "median (2.5th, 97.5th percentile)":
#  - `1` / `10`: event rate (%) in the bottom / top decile
#  - ratio: top-decile rate divided by bottom-decile rate
# The ratio now uses the same two-decimal format as the rates.
temp_bs_finished <- temp_bs %>%
    filter(logh_perc %in% c(1, 10)) %>%
    ungroup() %>%
    select(endpoint, logh_perc, rate, i) %>%
    pivot_wider(names_from = logh_perc, values_from = rate) %>%
    select(endpoint, i, `1`, `10`) %>%
    mutate(ratio = `10` / `1`) %>%
    group_by(endpoint) %>%
    # one row per requested quantile; `probs` labels them for the reshape below
    summarise(
        ratio = quantile(ratio, c(0.025, 0.5, 0.975), na.rm = TRUE),
        `1` = quantile(`1`, c(0.025, 0.5, 0.975), na.rm = TRUE),
        `10` = quantile(`10`, c(0.025, 0.5, 0.975), na.rm = TRUE),
        probs = c("CI025", "Median", "CI975")
    ) %>%
    ungroup() %>%
    mutate(`1` = `1` * 100, `10` = `10` * 100) %>%
    pivot_longer(c(ratio, `1`, `10`), names_to = "type", values_to = "value") %>%
    pivot_wider(names_from = probs, values_from = value) %>%
    mutate(
        unit = if_else(type == "ratio", "", "%"),
        string = glue("{fmt_2dp(Median)}{unit} ({fmt_2dp(CI025)}{unit}, {fmt_2dp(CI975)}{unit})")
    ) %>%
    select(endpoint, type, string) %>%
    pivot_wider(names_from = type, values_from = string) %>%
    select(endpoint, `1`, `10`, ratio) %>%
    mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
    arrange(endpoint) %>%
    mutate(endpoint = recode(endpoint, !!!endpoint_map))

In [None]:
library(gt)
# Supplementary Table 5: bootstrapped event rates in the bottom and top
# metabolomic-state deciles and their ratio, written as HTML to outputs/.
dir.create("outputs", showWarnings = FALSE)  # gtsave() does not create the folder
temp_bs_finished %>%
    gt(rowname_col = "endpoint") %>%
    tab_stubhead(label = "Endpoint") %>%
    tab_header(title = "Metabolomic State and Incident Disease in UK Biobank") %>%
    cols_label(
        `1` = md("Bottom 10%"),
        `10` = md("Top 10%"),
        # the column is a ratio of event rates (top / bottom decile), not an odds ratio
        ratio = md("Rate Ratio")
    ) %>%
    cols_align(align = "right", columns = c(`1`, `10`, ratio)) %>%
    gtsave("outputs/Suppl_Table5_EventRateRatios.html")

## Metabolomic State and Event Trajectories

In [None]:
# Incident-set participants (Metabolomics feature set) in the bottom (1-10),
# middle (45-55) and top (91-100) metabolomic-state percentiles. All other
# participants are dropped. MET is a factor with levels High, Mid, Low.
# A real NA replaces the old "NA" string sentinel, which only worked because
# factor() turned it into NA before filtering.
# NOTE(review): "Mid" spans 11 percentiles (45:55) while Low/High span 10 — confirm intended.
logh_T_metabolomics <- logh_T_inc %>%
    ungroup() %>%
    select(endpoint, features, eid, logh_perc, event, event_time) %>%
    filter(features == "Metabolomics") %>%
    mutate(
        MET = case_when(
            logh_perc %in% 91:100 ~ "High",
            logh_perc %in% 45:55 ~ "Mid",
            logh_perc %in% 1:10 ~ "Low",
            TRUE ~ NA_character_
        ),
        MET = factor(MET, levels = c("High", "Mid", "Low"))
    ) %>%
    filter(!is.na(MET))

In [None]:
# library() rather than require(): a missing package should stop here with an
# error instead of returning FALSE and failing later at geom_km().
library(ggquickeda)
options(repr.plot.width = 8, repr.plot.height = 5, repr.plot.res = 320)

# Colour per metabolomic-state group (dark = High, light = Low).
met_map <- c("High" = "#023768", "Mid" = "#4F8EC1", "Low" = "#7AC6FF")

# Kaplan-Meier curves per MET group, drawn as cumulative events (trans = "event")
# with censoring ticks and confidence bands, one facet per endpoint.
km_plot <- ggplot(
        logh_T_metabolomics %>% mutate(endpoint = factor(endpoint, levels = endpoint_order)),
        aes(time = event_time, status = event, fill = MET, color = MET, group = MET)
    ) +
    geom_km(trans = "event") +
    geom_kmticks(trans = "event", size = 0.3) +
    geom_kmband(trans = "event") +
    labs(x = "Time [Years]", y = "Cumulative Events [%]") +
    scale_color_manual(values = met_map) +
    scale_fill_manual(values = met_map) +
    scale_y_continuous(labels = function(x) round(x * 100, 1), expand = c(0, 0)) +
    scale_x_continuous(expand = expansion(add = c(0, .1)), breaks = c(5, 10)) +
    # `scales` spelled out: `scale =` only worked through partial argument matching
    facet_wrap(~endpoint, scales = "free", labeller = labeller(endpoint = endpoint_map), ncol = 6) +
    theme(legend.position = "none")

# Figure 2: Assemble and Export

In [None]:
# Figure 2 canvas size (inches) and resolution, also used when saving.
plot_width <- 8.25
plot_height <- 10
plot_dpi <- 320
options(repr.plot.width = plot_width, repr.plot.height = plot_height, repr.plot.res = plot_dpi)

# Event-rate panel stacked above the Kaplan-Meier panel (patchwork `/` operator).
fig2 <- met_events / km_plot  # + plot_annotation(tag_levels = 'A')

In [None]:
# Render the assembled Figure 2 through notebook auto-printing of the last value.
fig2

In [None]:
library(gt)
# Export Figure 2 (panels A and B) as PDF, reusing the size and resolution from
# the layout cell (plot_dpi instead of a hard-coded 320).
plot_name <- "Figures_2_AB"
dir.create("outputs", showWarnings = FALSE)  # ggsave() does not create the folder
ggsave(
    filename = glue("outputs/{plot_name}.pdf"),
    plot = fig2,
    device = "pdf",
    width = plot_width,
    height = plot_height,
    dpi = plot_dpi
)