## Initialize

In [None]:
#library(Rmisc)
library(tidyverse)
library(glue)
library(arrow)
library(patchwork)
library(ggdist)
library(ggtext)
library(ggforestplot)
library(cowplot)

In [None]:
# Choose the shared project root for the current host: node names containing
# "sc" use the /sc-projects mount, every other host uses the /data share.
base_path <- if (grepl("sc", Sys.info()[["nodename"]], fixed = TRUE)) {
  "/sc-projects/sc-proj-ukb-cvd"
} else {
  "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
}
print(base_path)

# Input dataset locations.
dataset_name <- "210714_metabolomics"
# NOTE(review): `path` is not referenced later in this notebook.
path <- "/data/analysis/ag-reils/steinfej/code/umbrella/pre/ukbb"
data_path <- glue("{base_path}/data")
dataset_path <- glue("{data_path}/3_datasets_post/{dataset_name}")

# Result locations for this project.
project_label <- "21_metabolomics_multitask"
project_path <- glue("{base_path}/results/projects/{project_label}")
figures_path <- glue("{project_path}/figures")
data_results_path <- glue("{project_path}/data")
data_results_path = glue("{project_path}/data")

## Load data

In [None]:
# Show every directory below the project root as full paths.
list.dirs(project_path, recursive = TRUE, full.names = TRUE)

In [None]:
# Identifier of the training run.
run <- "211007"

In [None]:
DSM <- "MultiTaskSurvivalTraining"
# Merged dataset and its companion column-description table (Feather files).
data <- glue("{dataset_path}/data_merged.feather") %>%
  arrow::read_feather()
data_description <- glue("{dataset_path}/description_merged.feather") %>%
  arrow::read_feather()

In [None]:
library(ggalt)

In [None]:
library("jsonlite")
# Shared colour palette, read from a JSON file relative to the working directory.
colors_path <- "colors.json"
colors_dict <- jsonlite::read_json(colors_path)

In [None]:
# Colour for every score label ("<module>_<features>"). The order of the names
# matters: names(color_map) becomes the factor-level order of `score` below.
color_map <- local({
  grey_light <- colors_dict$pastel$grey$light
  metabolomics_blue <- "#4F8EC1"
  c(
    "all" = "grey",
    "none" = "black",
    "SCORE_SCORE2" = grey_light,
    "SCORE_ASCVD" = colors_dict$pastel$grey$mid,
    "COX_Age+Sex" = grey_light,
    # metabolomics-only models share one blue
    "PCA_Metabolomics" = metabolomics_blue,
    "COX_Metabolomics" = metabolomics_blue,
    "DS_Metabolomics" = metabolomics_blue,
    "COX_SCORE2" = grey_light,
    "COX_ASCVD" = grey_light,
    "COX_PANEL" = grey_light,
    # predictor sets combined with metabolomics get distinct colours
    "DS_Age+Sex+Metabolomics" = "#53dd6c",
    "DS_SCORE2+Metabolomics" = colors_dict$pastel$red$mid,
    "DS_ASCVD+Metabolomics" = "#d8315b",
    "DS_PANEL+Metabolomics" = "#1e1b18"
  )
})

In [None]:
# Every score that has a colour, in colour-map order.
scores_full <- names(color_map)

In [None]:
name <- glue("benchmark1000_cindex_subgroups_220106")
# Subgroup C-index results; `score` is "<module>_<features>", matching the
# keys of color_map.
benchmark_cindex_sg <- glue("{data_results_path}/{name}.feather") %>%
  read_feather() %>%
  distinct() %>%
  unite("score", c(module, features), remove = FALSE) %>%
  distinct()

In [None]:
# Shared font sizes and a global classic theme used by every later plot.
base_size <- 8
title_size <- 8
facet_size <- 8
geom_text_size <- 3
library(ggplot2)
theme_set(
  theme_classic(base_size = base_size) +
    theme(
      strip.background = element_blank(),
      plot.title = element_text(size = title_size, hjust = 0),
      strip.text.x = element_text(size = facet_size),
      axis.title = element_text(size = 8),
      axis.text = element_text(size = 8, color = "black"),
      axis.text.x = element_text(size = 8, color = "black"),
      legend.position = "bottom",
      axis.line = element_line(size = 0.2),
      axis.ticks = element_line(size = 0.2)
    )
)

In [None]:
library(ggthemes)
# Display label for each endpoint column. The keys are data column names and
# must match them exactly, including the existing "obstructuve" spelling.
endpoint_map <- c(
  "M_MACE"                                  = "MACE",
  "M_all_cause_dementia"                    = "Dementia",
  "M_type_2_diabetes"                       = "T2 Diabetes",
  "M_liver_disease"                         = "Liver Disease",
  "M_renal_disease"                         = "Renal Disease",
  "M_atrial_fibrillation"                   = "Atrial Fibrillation",
  "M_heart_failure"                         = "Heart Failure",
  "M_coronary_heart_disease"                = "CHD",
  "M_venous_thrombosis"                     = "Ven. Thrombosis",
  "M_cerebral_stroke"                       = "Cerebral Stroke",
  "M_abdominal_aortic_aneurysm"             = "AAA",
  "M_peripheral_arterial_disease"           = "PAD",
  "M_chronic_obstructuve_pulmonary_disease" = "COPD",
  "M_asthma"                                = "Asthma",
  "M_parkinsons_disease"                    = "Parkinson's",
  "M_lung_cancer"                           = "Lung Cancer",
  "M_non_melanoma_skin_cancer"              = "Skin Cancer",
  "M_colon_cancer"                          = "Colon Cancer",
  "M_rectal_cancer"                         = "Rectal Cancer",
  "M_prostate_cancer"                       = "Prostate Cancer",
  "M_breast_cancer"                         = "Breast Cancer",
  "M_cataracts"                             = "Cataracts",
  "M_glaucoma"                              = "Glaucoma",
  "M_fractures"                             = "Fractures"
)

# Display order of the 24 endpoints; the first 12 and the last 12 are plotted
# in separate panels further below.
endpoint_order <- c(
  "M_MACE", "M_coronary_heart_disease", "M_cerebral_stroke",
  "M_all_cause_dementia", "M_heart_failure", "M_atrial_fibrillation",
  "M_type_2_diabetes", "M_liver_disease", "M_renal_disease",
  "M_peripheral_arterial_disease", "M_venous_thrombosis", "M_abdominal_aortic_aneurysm",
  "M_chronic_obstructuve_pulmonary_disease", "M_asthma", "M_parkinsons_disease",
  "M_cataracts", "M_glaucoma", "M_fractures",
  "M_lung_cancer", "M_non_melanoma_skin_cancer", "M_colon_cancer",
  "M_rectal_cancer", "M_prostate_cancer", "M_breast_cancer"
)

In [None]:
# Per-endpoint analysis cohort: a named list (one entry per endpoint in
# endpoint_order) of the eids that have NMR_FLAG == TRUE and a 0 in that
# endpoint's column. Extra restrictions apply to three endpoints:
#   M_MACE            -> statins == 0
#   M_breast_cancer   -> sex == "Female"
#   M_prostate_cancer -> sex == "Male"
# The NMR filter does not depend on the endpoint, so it runs once. map() builds
# a proper list; growing NULL with `[[<-` would give an atomic vector if the
# first endpoint had exactly one eid.
data_temp <- data %>% filter(NMR_FLAG == TRUE)
eids_included <- endpoint_order %>%
  set_names() %>%
  map(function(endpoint) {
    cohort <- data_temp %>% filter(!!sym(endpoint) == 0)
    if (endpoint == "M_MACE") {
      cohort <- cohort %>% filter(statins == 0)
    }
    if (endpoint == "M_breast_cancer") {
      cohort <- cohort %>% filter(sex == "Female")
    }
    if (endpoint == "M_prostate_cancer") {
      cohort <- cohort %>% filter(sex == "Male")
    }
    cohort$eid
  })

In [None]:
# Endpoint labels soft-wrapped at 20 characters; names are kept from endpoint_map.
em_wrap <- endpoint_map
em_wrap[] <- str_wrap(unname(endpoint_map), 20)
em_wrap

In [None]:
options(repr.plot.width = 8, repr.plot.height = 8)
library(ggbeeswarm)
scores_plot <- names(color_map)

# Benchmark rows for the coloured scores. `score` becomes a factor ordered like
# scores_full; scores outside that set turn into NA and are removed by the
# %in% filter (the earlier `filter(score == score)` added nothing to this).
temp <- benchmark_cindex_sg %>%
  mutate(score = factor(score, levels = scores_full)) %>%
  filter(score %in% scores_plot)

# One row per result with a column per score, plus the metabolomics minus
# Age+Sex difference.
temp_wide <- temp %>%
  select(-module, -features) %>%
  pivot_wider(names_from = "score", values_from = "cindex") %>%
  mutate(delta = `DS_Metabolomics` - `COX_Age+Sex`)

# Per endpoint: the median C-index of each score (the column was previously
# mislabelled `mean`, and an unused `max` was computed and dropped). Endpoints
# are then ordered by the median difference, DS_Metabolomics minus
# COX_Age+Sex, largest first.
temp_desc <- temp %>%
  group_by(endpoint, score) %>%
  summarise(median = median(cindex, na.rm = TRUE), .groups = "drop") %>%
  pivot_wider(names_from = "score", values_from = "median") %>%
  mutate(delta = `DS_Metabolomics` - `COX_Age+Sex`) %>%
  mutate(endpoint = fct_reorder(endpoint, desc(delta)))
endpoint_order_desc <- levels(temp_desc$endpoint)

In [None]:
# Short x-axis labels for the scores shown in the subgroup figure.
scores_map <- c(
  "DS_Metabolomics"         = "MET",
  "COX_Age+Sex"             = "AgeSex",
  "DS_Age+Sex+Metabolomics" = "AgeSex+MET",
  "COX_ASCVD"               = "ASCVD",
  "DS_ASCVD+Metabolomics"   = "ASCVD+MET",
  "COX_PANEL"               = "PANEL",
  "DS_PANEL+Metabolomics"   = "PANEL+MET"
)

## Figure 3 - Performance in Context

In [None]:
library(scales)

In [None]:
# Notebook canvas size and resolution for this section.
plot_width <- 8.25
plot_height <- 5.5
plot_dpi <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_dpi
)

# Scores in the subgroup figure: metabolomics alone, then each baseline
# followed by its "+Metabolomics" counterpart.
scores_plot <- c(
  "DS_Metabolomics",
  "COX_Age+Sex", "DS_Age+Sex+Metabolomics",
  "COX_ASCVD", "DS_ASCVD+Metabolomics",
  "COX_PANEL", "DS_PANEL+Metabolomics"
)

# Endpoint display order (identical to the earlier definition in this notebook).
endpoint_order <- c(
  "M_MACE", "M_coronary_heart_disease", "M_cerebral_stroke",
  "M_all_cause_dementia", "M_heart_failure", "M_atrial_fibrillation",
  "M_type_2_diabetes", "M_liver_disease", "M_renal_disease",
  "M_peripheral_arterial_disease", "M_venous_thrombosis", "M_abdominal_aortic_aneurysm",
  "M_chronic_obstructuve_pulmonary_disease", "M_asthma", "M_parkinsons_disease",
  "M_cataracts", "M_glaucoma", "M_fractures",
  "M_lung_cancer", "M_non_melanoma_skin_cancer", "M_colon_cancer",
  "M_rectal_cancer", "M_prostate_cancer", "M_breast_cancer"
)

In [None]:
library(tidybayes)

In [None]:
## Sanity check that the cohort sums up: total subgroup size (len_sg) per
## endpoint/score/group for iteration 0; three random rows are shown.
temp %>%
  filter(iteration == 0) %>%
  group_by(endpoint, score, group, iteration) %>%
  summarise(sum(len_sg)) %>%
  ungroup() %>%
  sample_n(3)

In [None]:
# Facet-strip labels for the subgrouping variables (with explicit line breaks).
group_map <- c(
  "age"               = "Age",
  "sex"               = "Biological\nSex",
  "ethnic_background" = "Ethnic\nBackground"
)

In [None]:
# The 12 endpoints shown in the first subgroup panel.
head(endpoint_order, 12)

## Get Subgroup information

In [None]:
# Long format: one row per (eid, endpoint) holding that endpoint column's value
# as `prev` (presumably a prevalence flag; confirm against the data dictionary).
data_prev <- data %>%
  select(eid, all_of(endpoint_order)) %>%
  pivot_longer(-eid, names_to = "endpoint", values_to = "prev")

In [None]:
# Long format of the "<endpoint>_event" and "<endpoint>_event_time" columns:
# one row per (eid, endpoint). Only the trailing suffix is stripped to recover
# the endpoint name; the previous unanchored str_replace_all() removed every
# occurrence of the substring, which would mangle any endpoint name that
# contains "_event" elsewhere.
data_event <- data %>%
  select(eid, ends_with("_event")) %>%
  pivot_longer(ends_with("_event"), names_to = "endpoint", values_to = "event") %>%
  mutate(endpoint = str_remove(endpoint, "_event$"))
data_event_time <- data %>%
  select(eid, ends_with("_event_time")) %>%
  pivot_longer(ends_with("_event_time"), names_to = "endpoint", values_to = "event_time") %>%
  mutate(endpoint = str_remove(endpoint, "_event_time$"))

In [None]:
# Per (eid, endpoint in endpoint_order): `prev` plus `event_10`, which is
# 1 for an event with event_time <= 10 and 0 for no event or a later event.
# It is NA when event/event_time are missing (the time unit is presumably
# years; confirm). Filtering endpoints first gives the same rows as filtering
# after the joins.
data_n_event <- data_event %>%
  filter(endpoint %in% endpoint_order) %>%
  left_join(data_prev, by = c("eid", "endpoint")) %>%
  left_join(data_event_time, by = c("eid", "endpoint")) %>%
  mutate(event_10 = case_when(
    event == 0 ~ 0,
    event == 1 & event_time > 10 ~ 0,
    event == 1 & event_time <= 10 ~ 1
  )) %>%
  select(eid, endpoint, prev, event_10)

In [None]:
# NMR participants tagged with an age band, sex and ethnic background, joined to
# their per-endpoint rows. The result is stacked so that each subgrouping
# variable becomes a (group, subgroup) pair.
data_sgs <- data %>%
  filter(NMR_FLAG == TRUE) %>%
  mutate(age = case_when(
    age_at_recruitment < 50 ~ "<50",
    age_at_recruitment >= 50 & age_at_recruitment <= 60 ~ "50-60",
    age_at_recruitment > 60 ~ ">60"
  )) %>%
  select(eid, age, sex, ethnic_background) %>%
  left_join(data_n_event, by = "eid") %>%
  pivot_longer(
    c(age, sex, ethnic_background),
    names_to = "group",
    values_to = "subgroup"
  )

In [None]:
# Restrict data_sgs, one endpoint at a time, to that endpoint's analysis cohort
# (eids_included), then stack the per-endpoint pieces.
data_sgs_included <- endpoint_order %>%
  map(function(ep) {
    data_sgs %>%
      filter(endpoint == !!ep) %>%
      filter(eid %in% eids_included[[ep]]) %>%
      ungroup()
  }) %>%
  bind_rows()

In [None]:
# Size (len_sg) and summed event_10 (events_sg) of every endpoint/group/subgroup,
# with an "n=events/size" label. Keys become factors in plotting order, and
# subgroups outside the listed levels are dropped.
# NOTE(review): this aggregates data_sgs, not the per-endpoint cohort
# data_sgs_included. Confirm which one the counts should come from.
# sum(event_10) is NA if any event_10 in the subgroup is NA.
data_sgs_agg <- data_sgs %>%
  group_by(endpoint, group, subgroup) %>%
  summarise(len_sg = n(), events_sg = sum(event_10)) %>%
  ungroup() %>%
  mutate(
    label = glue("n={events_sg}/{len_sg}"),
    endpoint = factor(endpoint, levels = endpoint_order),
    group = factor(group, levels = c("age", "sex", "ethnic_background")),
    subgroup = factor(subgroup, levels = c(
      "Female", "Male",
      "<50", "50-60", ">60",
      "White", "Asian", "Black", "Mixed", "Chinese"
    ))
  ) %>%
  filter(!is.na(subgroup))
data_sgs_agg

In [None]:
# Subgroup C-index results for the scores in scores_plot. Keys are factors in
# plotting order. Rows are annotated with subgroup size/event counts from
# data_sgs_agg and kept only for subgroups with at least 100 events.
temp <- benchmark_cindex_sg %>%
  mutate(
    score = factor(score, levels = scores_plot),
    endpoint = factor(endpoint, levels = endpoint_order),
    group = factor(group, levels = c("age", "sex", "ethnic_background")),
    subgroup = factor(subgroup, levels = c(
      "Female", "Male",
      "<50", "50-60", ">60",
      "White", "Asian", "Black", "Mixed", "Chinese"
    ))
  ) %>%
  filter(score %in% scores_plot)

# left_join() has no `on` argument. The former `on = c(endpoint, group,
# subgroup)` landed in `...`: dplyr < 1.1 ignored it and fell back to a natural
# join on all shared columns, and dplyr >= 1.1 errors because `...` must be
# empty. Joining explicitly on the shared columns keeps the previous result on
# every version. Besides endpoint/group/subgroup, these include len_sg, which
# the benchmark table also carries.
# NOTE(review): rows whose len_sg differs between the two tables get NA
# events_sg and are dropped by the filter. Confirm both counts describe the
# same cohort.
temp <- temp %>%
  left_join(data_sgs_agg, by = intersect(names(temp), names(data_sgs_agg))) %>%
  filter(events_sg >= 100)

In [None]:
plot_width <- 10
plot_height <- 25
plot_dpi <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_dpi
)

# First 12 endpoints of endpoint_order.
temp_1 <- temp %>% filter(endpoint %in% head(endpoint_order, 12))
# One row per panel with the subgroup's events/size, taken from the PANEL+MET rows.
temp_labels <- temp_1 %>%
  filter(score == "DS_PANEL+Metabolomics") %>%
  group_by(endpoint, group, subgroup, events_sg, len_sg, label) %>%
  summarise(median_cindex = median(cindex))

# Violins of C-index per score, faceted endpoint x (group, subgroup).
met_1 <- ggplot(temp_1, aes(x = score)) +
  labs(x = NULL, y = "C-Index") +
  # events/size badge at the top of each panel; x = 4 is the middle of the
  # seven scores on the axis
  geom_label(
    data = temp_labels,
    mapping = aes(label = glue("{events_sg}/{len_sg}"), x = 4, y = Inf),
    hjust = 0.5, vjust = 1, size = 2.5, fill = "grey90", alpha = 0.7
  ) +
  geom_violin(mapping = aes(y = cindex, color = score, fill = score), alpha = 0.7) +
  scale_x_discrete(labels = scores_map) +
  scale_y_continuous(breaks = scales::extended_breaks()) +
  scale_color_manual(values = color_map) +
  scale_fill_manual(values = color_map) +
  facet_grid(
    endpoint ~ group + subgroup,
    scales = "free_y",
    labeller = labeller(endpoint = endpoint_map, group = group_map),
    switch = "y",
    drop = TRUE
  ) +
  theme(
    legend.position = "none",
    axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5),
    panel.spacing = unit(0.8, "lines"),
    panel.grid.major = element_line(colour = "grey50", size = 0.1),
    strip.text = element_text(size = 10),
    strip.placement = "outside"
  )

met_1

In [None]:
plot_width <- 10
plot_height <- 25
plot_dpi <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_dpi
)

# Last 12 endpoints of endpoint_order.
temp_2 <- temp %>% filter(endpoint %in% tail(endpoint_order, 12))
# One row per panel with the subgroup's events/size, taken from the PANEL+MET rows.
temp_labels <- temp_2 %>%
  filter(score == "DS_PANEL+Metabolomics") %>%
  group_by(endpoint, group, subgroup, events_sg, len_sg, label) %>%
  summarise(median_cindex = median(cindex))

# Same layout as the first-half panel, for the remaining endpoints.
met_2 <- ggplot(temp_2, aes(x = score)) +
  labs(x = NULL, y = "C-Index") +
  # events/size badge at the top of each panel; x = 4 is the middle of the
  # seven scores on the axis
  geom_label(
    data = temp_labels,
    mapping = aes(label = glue("{events_sg}/{len_sg}"), x = 4, y = Inf),
    hjust = 0.5, vjust = 1, size = 2.5, fill = "grey90", alpha = 0.7
  ) +
  geom_violin(mapping = aes(y = cindex, color = score, fill = score), alpha = 0.7) +
  scale_x_discrete(labels = scores_map) +
  scale_y_continuous(breaks = scales::extended_breaks()) +
  scale_color_manual(values = color_map) +
  scale_fill_manual(values = color_map) +
  facet_grid(
    endpoint ~ group + subgroup,
    scales = "free_y",
    labeller = labeller(endpoint = endpoint_map, group = group_map),
    switch = "y",
    drop = TRUE
  ) +
  theme(
    legend.position = "none",
    axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5),
    panel.spacing = unit(0.8, "lines"),
    panel.grid.major = element_line(colour = "grey50", size = 0.1),
    strip.text = element_text(size = 10),
    strip.placement = "outside"
  )

met_2

In [None]:
plot_width <- 15
plot_height <- 20
plot_dpi <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_dpi
)
# Both endpoint halves side by side (patchwork `|` operator).
fig_sgs <- met_1 | met_2
fig_sgs

In [None]:
library(gt)
plot_name <- "Suppl_Figures_5_SubgroupPerformance"
# Write the combined subgroup figure to PDF at the current canvas size.
# NOTE(review): the output directory is a hard-coded home path rather than
# figures_path.
ggsave(
  filename = glue("/home/steinfej/code/21_metabolomics_analysis/Round1/Figures/outputs/{plot_name}.pdf"),
  plot = fig_sgs,
  device = "pdf",
  width = plot_width,
  height = plot_height,
  dpi = plot_dpi
)