## Initialize

In [None]:
#library(Rmisc)
library(tidyverse)
library(glue)
library(arrow)
library(patchwork)
library(ggdist)

In [None]:
# Pick the storage root from the host name: nodes whose name contains "sc"
# use the /sc-projects mount, every other host uses the shared cardioRS tree.
base_path <- if (grepl("sc", Sys.info()[["nodename"]], fixed = TRUE)) {
  "/sc-projects/sc-proj-ukb-cvd"
} else {
  "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
}
print(base_path)

# Input dataset locations.
dataset_name <- "210714_metabolomics"
path <- "/data/analysis/ag-reils/steinfej/code/umbrella/pre/ukbb"
data_path <- glue("{base_path}/data")
dataset_path <- glue("{data_path}/3_datasets_post/{dataset_name}")

# Project result locations (figures and result tables).
project_label <- "21_metabolomics_multitask"
project_path <- glue("{base_path}/results/projects/{project_label}")
figures_path <- glue("{project_path}/figures")
data_results_path <- glue("{project_path}/data")

## Load data

In [None]:
# Show every directory below the project folder (full paths, recursive --
# both are the defaults of list.dirs()).
list.dirs(project_path)

In [None]:
# Run identifier; it is interpolated into the result file names read below.
run <- "220104"

In [None]:
DSM <- "MultiTaskSurvivalTraining"

# Merged dataset and its description table.
data <- arrow::read_feather(glue("{dataset_path}/data_merged.feather"))
data_description <- arrow::read_feather(glue("{dataset_path}/description_merged.feather"))

# Predictions of the selected run.
predictions <- arrow::read_feather(glue("{data_results_path}/predictions_{run}_metabolomics.feather"))

# Log-hazards of the selected run, reshaped to long format: every `logh*`
# column becomes rows, and its name is split by the regex into `endpoint`
# and `features`. The first capture group is greedy, so the split happens at
# the last underscore of the column name.
loghazards <- arrow::read_feather(glue("{data_results_path}/loghazards_model_{run}_metabolomics.feather")) %>%
  pivot_longer(
    cols = starts_with("logh"),
    names_to = c("endpoint", "features"),
    names_pattern = "logh_?(.*)_(.*)$",
    values_to = "logh"
  )

In [None]:
# One row per (eid, endpoint) with the columns `event` and `event_time`.
data_events <- data %>%
  select(eid, ends_with("event"), ends_with("event_time")) %>%
  pivot_longer(
    cols = -eid,
    names_to = c("endpoint", "type"),
    names_pattern = "(.*)(event_time|event)",
    values_to = "value"
  ) %>%
  # The captured endpoint prefix still carries one trailing character
  # (presumably the "_" separator -- verify against the column names); drop it.
  mutate(endpoint = stringr::str_sub(endpoint, end = -2)) %>%
  pivot_wider(names_from = "type", values_from = "value")

In [None]:
# Attach event / event_time to every log-hazard row (all log-hazard rows kept).
loghazards_tte <- left_join(loghazards, data_events, by = c("endpoint", "eid"))

In [None]:
# Keep the rows of the "test" split and add the hazard ratio exp(logh).
logh_T <- loghazards_tte %>%
  filter(split == "test") %>%
  mutate(hr = exp(logh))

In [None]:
# Assign each test-set row to one of 100 log-hazard bins (ntile) within its
# endpoint x feature-set group. logh_T stays grouped by (endpoint, features).
logh_T <- logh_T %>%
  group_by(endpoint, features) %>%
  mutate(logh_perc = ntile(logh, 100))

# Mean of `event` per bin (the observed event rate if `event` is coded 0/1).
# The bins computed above are reused instead of being recomputed, and the
# resulting grouping (endpoint, features) is stated explicitly.
logh_T_agg <- logh_T %>%
  group_by(endpoint, features, logh_perc) %>%
  summarise(ratio = mean(event), .groups = "drop_last")

In [None]:
# Mean log-hazard per endpoint and event status, spread into the columns `0`
# and `1`; `delta` is their difference and endpoints are sorted by it,
# largest first. Assumes `event` takes the values 0 and 1.
labels <- logh_T %>%
  group_by(endpoint, event) %>%
  summarise(mean_logh = mean(logh)) %>%
  pivot_wider(names_from = "event", values_from = "mean_logh") %>%
  mutate(delta = `1` - `0`) %>%
  arrange(desc(delta))

In [None]:
# Score identifiers in "<module>_<features>" form.
scores <- c(
  "DS_PANEL",
  "COX_PANEL",
  "DS_PANELmetabolitesOverlap",
  "DS_Metabolomics"
)

In [None]:
library(ggalt)

In [None]:
library("jsonlite")
# Colour palette definition; the relative path is resolved against the
# current working directory.
colors_path <- "colors.json"
colors_dict <- read_json(colors_path)

In [None]:
# Colour per score. Reference scores use the palette's light grey, every
# Metabolomics-only score shares one blue, and the combined scores get their
# own colours.
color_map <- local({
  light_grey <- colors_dict$pastel$grey$light
  metabolomics_blue <- "#4F8EC1"
  c(
    "all" = "grey",
    "none" = "black",
    "COX_Age+Sex" = light_grey,
    "PCA_Metabolomics" = metabolomics_blue,
    "COX_Metabolomics" = metabolomics_blue,
    "DS_Metabolomics" = metabolomics_blue,
    "COX_SCORE2" = light_grey,
    "COX_ASCVD" = light_grey,
    "COX_PANEL" = light_grey,
    "DS_Age+Sex+Metabolomics" = "#53dd6c", # alternative: colors_dict$pastel$orange$mid
    "DS_ASCVD+Metabolomics" = "#d8315b",   # alternative: colors_dict$pastel$red$mid
    "DS_PANEL+Metabolomics" = "#1e1b18"    # alternative: colors_dict$pastel$red$dark
  )
})

In [None]:
# All score identifiers that have a colour assigned.
scores_full <- names(color_map)

In [None]:
# C-index benchmark of the selected run. `score` is "<module>_<features>"
# (default "_" separator of unite()); the source columns are kept.
name <- glue("benchmark_cindex_{run}")
benchmark_cindex_general <- glue("{data_results_path}/{name}.feather") %>%
  read_feather() %>%
  distinct() %>%
  unite("score", c(module, features), remove = FALSE) %>%
  distinct()

In [None]:
# Which scores are present in the benchmark file?
unique(benchmark_cindex_general[["score"]])

In [None]:
# Global font sizes for all figures of this notebook.
base_size <- 8
title_size <- 8
facet_size <- 8
geom_text_size <- 3
library(ggplot2)

# Default theme: classic look, no facet-strip background, left-aligned title,
# black 8pt axis text, thin axis lines and ticks, legend at the bottom.
# NOTE(review): `size` for line elements is deprecated in favour of
# `linewidth` from ggplot2 3.4.0 on; kept to match the installed version.
theme_set(
  theme_classic(base_size = base_size) +
    theme(
      strip.background = element_blank(),
      strip.text.x = element_text(size = facet_size),
      plot.title = element_text(size = title_size, hjust = 0),
      axis.title = element_text(size = 8),
      axis.text = element_text(size = 8, color = "black"),
      axis.text.x = element_text(size = 8, color = "black"),
      axis.line = element_line(size = 0.2),
      axis.ticks = element_line(size = 0.2),
      legend.position = "bottom"
    )
)

In [None]:
library(ggthemes)
# Display label for every endpoint identifier. The identifiers are used as
# lookup keys elsewhere, so they are kept verbatim (including their spelling).
endpoint_map <- c(
  "M_MACE" = "MACE",
  "M_all_cause_dementia" = "Dementia",
  "M_type_2_diabetes" = "T2 Diabetes",
  "M_liver_disease" = "Liver Disease",
  "M_renal_disease" = "Renal Disease",
  "M_atrial_fibrillation" = "Atrial Fibrillation",
  "M_heart_failure" = "Heart Failure",
  "M_coronary_heart_disease" = "CHD",
  "M_venous_thrombosis" = "Ven. Thrombosis",
  "M_cerebral_stroke" = "Cerebral Stroke",
  "M_abdominal_aortic_aneurysm" = "AAA",
  "M_peripheral_arterial_disease" = "PAD",
  "M_chronic_obstructuve_pulmonary_disease" = "COPD",
  "M_asthma" = "Asthma",
  "M_parkinsons_disease" = "Parkinson's",
  "M_lung_cancer" = "Lung Cancer",
  "M_non_melanoma_skin_cancer" = "Skin Cancer",
  "M_colon_cancer" = "Colon Cancer",
  "M_rectal_cancer" = "Rectal Cancer",
  "M_prostate_cancer" = "Prostate Cancer",
  "M_breast_cancer" = "Breast Cancer",
  "M_cataracts" = "Cataracts",
  "M_glaucoma" = "Glaucoma",
  "M_fractures" = "Fractures"
)

# Fixed display order of the endpoints (used as factor levels for facets).
endpoint_order <- c(
  "M_MACE",
  "M_coronary_heart_disease",
  "M_cerebral_stroke",
  "M_all_cause_dementia",
  "M_heart_failure",
  "M_atrial_fibrillation",
  "M_type_2_diabetes",
  "M_liver_disease",
  "M_renal_disease",
  "M_peripheral_arterial_disease",
  "M_venous_thrombosis",
  "M_abdominal_aortic_aneurysm",
  "M_chronic_obstructuve_pulmonary_disease",
  "M_asthma",
  "M_parkinsons_disease",
  "M_cataracts",
  "M_glaucoma",
  "M_fractures",
  "M_lung_cancer",
  "M_non_melanoma_skin_cancer",
  "M_colon_cancer",
  "M_rectal_cancer",
  "M_prostate_cancer",
  "M_breast_cancer"
)

In [None]:
# Same lookup as endpoint_map, with labels wrapped at 20 characters.
em_wrap <- setNames(str_wrap(unname(endpoint_map), 20), names(endpoint_map))
em_wrap

In [None]:
# Axis label for every score identifier ("<module>_<features>").
scores_map <- c(
  "DS_Metabolomics" = "MET(NMR)",
  "COX_Metabolomics" = "COX MET",
  "PCA_Metabolomics" = "PCA MET",
  "DS_PANEL" = "MSM PANEL",
  "DS_PANELmetabolites" = "MET(PANEL)",
  "DS_PANELmetabolitesOverlap" = "MET(PANEL)",
  "DS_Age+Sex+PANELmetabolitesOverlap" = "AgeSex+MET(PANEL)",
  "COX_Age+Sex" = "AgeSex",
  "DS_Age+Sex+Metabolomics" = "AgeSex+MET(NMR)",
  "COX_ASCVD" = "ASCVD",
  "DS_ASCVD+Metabolomics" = "ASCVD+MET",
  "COX_PANEL" = "COX PANEL",
  "DS_PANEL+Metabolomics" = "PANEL+MET"
)

In [None]:
options(repr.plot.width = 8, repr.plot.height = 8)
library(ggbeeswarm)
scores_plot <- names(color_map)

# Keep only benchmark rows whose score has a label in scores_map. Scores
# outside the factor levels become NA and are dropped explicitly with
# is.na() (replaces the `score == score` self-comparison and the superseded
# mutate_at()).
temp <- benchmark_cindex_general %>%
  mutate(score = factor(score, levels = names(scores_map))) %>%
  filter(!is.na(score))

# One column per score; delta = DS_Metabolomics minus COX_Age+Sex.
temp_wide <- temp %>%
  select(-module, -features) %>%
  pivot_wider(names_from = "score", values_from = "cindex") %>%
  mutate(delta = `DS_Metabolomics` - `COX_Age+Sex`)

# Median C-index per endpoint and score, spread to one column per score.
# Endpoints become a factor ordered by decreasing delta of the medians.
temp_desc <- temp %>%
  group_by(endpoint, score) %>%
  summarise(median_cindex = median(cindex, na.rm = TRUE), .groups = "drop") %>%
  pivot_wider(names_from = "score", values_from = "median_cindex") %>%
  mutate(delta = `DS_Metabolomics` - `COX_Age+Sex`) %>%
  mutate(endpoint = fct_reorder(endpoint, desc(delta)))
endpoint_order_desc <- levels(temp_desc$endpoint)

In [None]:
library(ggdist)

In [None]:
library(ggtext)

In [None]:
library(ggforestplot)

In [None]:
library(cowplot)

## NMR vs. PANEL Metabolites

In [None]:
#MET(PANEL) = ['albumin', 'cholesterol', 'hdl_cholesterol', 'ldl_direct', 'triglycerides', 'glucose', 'creatinine']

In [None]:
library(scales)

In [None]:
# Scores compared in this figure, in plotting order.
scores_plot <- c(
  "COX_Age+Sex",
  "DS_PANELmetabolitesOverlap",
  "DS_Metabolomics",
  "DS_Age+Sex+PANELmetabolitesOverlap",
  "DS_Age+Sex+Metabolomics"
)

# Facet order of the endpoints.
endpoint_order <- c(
  "M_MACE", "M_coronary_heart_disease", "M_cerebral_stroke",
  "M_all_cause_dementia", "M_heart_failure", "M_atrial_fibrillation",
  "M_type_2_diabetes", "M_liver_disease", "M_renal_disease",
  "M_peripheral_arterial_disease", "M_venous_thrombosis", "M_abdominal_aortic_aneurysm",
  "M_chronic_obstructuve_pulmonary_disease", "M_asthma", "M_parkinsons_disease",
  "M_cataracts", "M_glaucoma", "M_fractures",
  "M_lung_cancer", "M_non_melanoma_skin_cancer", "M_colon_cancer",
  "M_rectal_cancer", "M_prostate_cancer", "M_breast_cancer"
)

# Restrict the benchmark to the selected scores. Scores outside scores_plot
# become NA through factor() and are removed explicitly with is.na()
# (replaces the `score == score` self-comparison and the superseded
# mutate_at()). Endpoints become a factor in display order.
temp <- benchmark_cindex_general %>%
  mutate(
    score = factor(score, levels = scores_plot),
    endpoint = factor(endpoint, levels = endpoint_order)
  ) %>%
  filter(!is.na(score))

# Per endpoint and score: median C-index (column name `mean` kept for
# compatibility) and maximum C-index.
temp_desc <- temp %>%
  group_by(endpoint, score) %>%
  summarise(mean = median(cindex, na.rm = TRUE), max = max(cindex), .groups = "drop")

### Diff endpoints

In [None]:
# Per endpoint and feature set: median C-index with its quantile interval
# (median_qi defaults), formatted as "median (lower, upper)".
# NOTE(review): grouping by `features` is only safe while each selected score
# has a distinct feature set, as appears to be the case for this scores_plot.
ep_table <- temp %>%
  select(endpoint, features, iteration, cindex) %>%
  group_by(endpoint, features) %>%
  median_qi(cindex) %>%
  mutate(
    result = glue("{round(cindex, 3)} ({round(.lower, 3)}, {round(.upper, 3)})")
  ) %>%
  select(endpoint, features, cindex, result)

In [None]:
# "Overall" row: average the C-index over endpoints within each iteration,
# then take the median and quantile interval over iterations per feature set.
agg_table <- temp %>%
  group_by(features, iteration) %>%
  summarise(cindex = mean(cindex)) %>%
  group_by(features) %>%
  median_qi(cindex) %>%
  ungroup() %>%
  mutate(
    result = glue("{round(cindex, 3)} ({round(.lower, 3)}, {round(.upper, 3)})"),
    endpoint = "Overall"
  ) %>%
  select(endpoint, features, cindex, result)
#agg_table

In [None]:
# Wide performance table: one row per endpoint (plus "Overall"), one column
# per feature set holding the formatted "median (lower, upper)" string.
perf_table <- bind_rows(ep_table, agg_table) %>%
  select(-cindex) %>%
  pivot_wider(names_from = "features", values_from = "result")

# Replace endpoint identifiers by their display labels.
perf_table[["endpoint"]] <- recode(perf_table[["endpoint"]], !!!endpoint_map)

In [None]:
# Figure dimensions (inches) and resolution, also applied to the notebook
# display.
plot_width <- 8.25
plot_height <- 5.5
plot_dpi <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_dpi
)

# C-index distribution per score, one facet per endpoint with its own y-axis.
met_discrimination <- ggplot(temp, aes(x = score)) +
  labs(x = NULL, y = "Absolute C-Index") +
  stat_gradientinterval(
    mapping = aes(y = cindex, color = score, fill = score),
    alpha = 0.7,
    fatten_point = 0.8,
    interval_size_range = c(0.3, 0.6),
    slab_alpha = 0
  ) +
  scale_x_discrete(labels = scores_map) +
  scale_y_continuous(breaks = scales::extended_breaks()) +
  scale_color_manual(values = color_map) +
  scale_fill_manual(values = color_map) +
  facet_wrap(
    ~endpoint,
    scales = "free_y",
    labeller = labeller(endpoint = endpoint_map),
    ncol = 6
  ) +
  theme(
    legend.position = "none",
    axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1),
    panel.spacing = unit(0.8, "lines"),
    panel.grid.major.y = element_line()
  )

met_discrimination

In [None]:
library(gt)
# Write the figure as PDF; the plot is passed explicitly instead of being
# piped into ggsave()'s second argument.
plot_name <- "Suppl_Figures_2B_PANELMetabolites"
ggsave(
  filename = glue("/home/steinfej/code/21_metabolomics_analysis/Round1/Figures/outputs/{plot_name}.pdf"),
  plot = met_discrimination,
  device = "pdf",
  width = plot_width,
  height = plot_height,
  dpi = plot_dpi
)

## PANEL predictors in multitask model

In [None]:
library(scales)

In [None]:
# Scores compared in this figure, in plotting order.
scores_plot <- c(
  "COX_Age+Sex",
  "COX_PANEL",
  "DS_PANEL"
)

# Facet order of the endpoints.
endpoint_order <- c(
  "M_MACE", "M_coronary_heart_disease", "M_cerebral_stroke",
  "M_all_cause_dementia", "M_heart_failure", "M_atrial_fibrillation",
  "M_type_2_diabetes", "M_liver_disease", "M_renal_disease",
  "M_peripheral_arterial_disease", "M_venous_thrombosis", "M_abdominal_aortic_aneurysm",
  "M_chronic_obstructuve_pulmonary_disease", "M_asthma", "M_parkinsons_disease",
  "M_cataracts", "M_glaucoma", "M_fractures",
  "M_lung_cancer", "M_non_melanoma_skin_cancer", "M_colon_cancer",
  "M_rectal_cancer", "M_prostate_cancer", "M_breast_cancer"
)

# Restrict the benchmark to the selected scores. Scores outside scores_plot
# become NA through factor() and are removed explicitly with is.na()
# (replaces the `score == score` self-comparison and the superseded
# mutate_at()). Endpoints become a factor in display order.
temp <- benchmark_cindex_general %>%
  mutate(
    score = factor(score, levels = scores_plot),
    endpoint = factor(endpoint, levels = endpoint_order)
  ) %>%
  filter(!is.na(score))

# Per endpoint and score: median C-index (column name `mean` kept for
# compatibility) and maximum C-index.
temp_desc <- temp %>%
  group_by(endpoint, score) %>%
  summarise(mean = median(cindex, na.rm = TRUE), max = max(cindex), .groups = "drop")

### Diff endpoints

In [None]:
# Per endpoint and score: median C-index with its quantile interval
# (median_qi defaults), formatted as "median (lower, upper)".
# Grouping is by `score` (module + features), not by `features`: COX_PANEL
# and DS_PANEL share the feature set "PANEL", so grouping by `features`
# would pool the C-indices of two different models into a single group.
ep_table <- temp %>%
  select(endpoint, score, iteration, cindex) %>%
  group_by(endpoint, score) %>%
  median_qi(cindex) %>%
  ungroup() %>%
  mutate(result = glue("{round(cindex, 3)} ({round(.lower, 3)}, {round(.upper, 3)})")) %>%
  select(endpoint, score, cindex, result)

In [None]:
# "Overall" row: average the C-index over endpoints within each iteration,
# separately per score, then take the median and quantile interval over
# iterations.
agg_table <- temp %>%
  group_by(score, iteration) %>%
  summarise(cindex = mean(cindex), .groups = "drop") %>%
  group_by(score) %>%
  median_qi(cindex) %>%
  ungroup() %>%
  mutate(result = glue("{round(cindex, 3)} ({round(.lower, 3)}, {round(.upper, 3)})")) %>%
  mutate(endpoint = "Overall") %>%
  select(endpoint, score, cindex, result)
#agg_table

In [None]:
# Wide performance table: one row per endpoint (plus "Overall"), one column
# per score holding the formatted "median (lower, upper)" string.
perf_table <- bind_rows(ep_table, agg_table) %>%
  select(-cindex) %>%
  pivot_wider(names_from = "score", values_from = "result")
# Replace endpoint identifiers by their display labels.
perf_table$endpoint <- recode(perf_table$endpoint, !!!endpoint_map)

In [None]:
# Figure dimensions (inches) and resolution, also applied to the notebook
# display.
plot_width <- 8.25
plot_height <- 5.5
plot_dpi <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_dpi
)

# C-index distribution per score, one facet per endpoint with its own y-axis.
met_discrimination <- ggplot(temp, aes(x = score)) +
  labs(x = NULL, y = "Absolute C-Index") +
  stat_gradientinterval(
    mapping = aes(y = cindex, color = score, fill = score),
    alpha = 0.7,
    fatten_point = 0.8,
    interval_size_range = c(0.3, 0.6),
    slab_alpha = 0
  ) +
  scale_x_discrete(labels = scores_map) +
  scale_y_continuous(breaks = scales::extended_breaks()) +
  scale_color_manual(values = color_map) +
  scale_fill_manual(values = color_map) +
  facet_wrap(
    ~endpoint,
    scales = "free_y",
    labeller = labeller(endpoint = endpoint_map),
    ncol = 6
  ) +
  theme(
    legend.position = "none",
    axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1),
    panel.spacing = unit(0.8, "lines"),
    panel.grid.major.y = element_line()
  )

met_discrimination

In [None]:
library(gt)
# Write the figure as PDF; the plot is passed explicitly instead of being
# piped into ggsave()'s second argument.
plot_name <- "Suppl_Figures_3_NNPanel"
ggsave(
  filename = glue("/home/steinfej/code/21_metabolomics_analysis/Round1/Figures/outputs/{plot_name}.pdf"),
  plot = met_discrimination,
  device = "pdf",
  width = plot_width,
  height = plot_height,
  dpi = plot_dpi
)