## Initialize

In [None]:
#library(Rmisc)
library(tidyverse)
library(glue)
library(arrow)
library(patchwork)
library("ggbeeswarm")

In [None]:
# Pick the storage root: any host whose node name contains "sc" uses the
# /sc-projects mount, every other host uses the ag-reils analysis share.
base_path <- if (grepl("sc", Sys.info()[["nodename"]], fixed = TRUE)) {
  "/sc-projects/sc-proj-ukb-cvd"
} else {
  "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
}
print(base_path)

# Input dataset location.
dataset_name <- "210714_metabolomics"
path <- "/data/analysis/ag-reils/steinfej/code/umbrella/pre/ukbb"
data_path <- glue("{base_path}/data")
dataset_path <- glue("{data_path}/3_datasets_post/{dataset_name}")

# Project result locations.
project_label <- "21_metabolomics_multitask"
project_path <- glue("{base_path}/results/projects/{project_label}")
figures_path <- glue("{project_path}/figures")
data_results_path <- glue("{project_path}/data")

## Load data

In [None]:
# Show every directory below the project folder (full paths, recursive —
# both are list.dirs() defaults).
list.dirs(project_path)

In [None]:
# Model run whose prediction / log-hazard exports are analysed below.
run <- "211007"

# Merged analysis dataset and its column descriptions.
data <- arrow::read_feather(glue("{dataset_path}/data_merged.feather"))
data_description <- arrow::read_feather(glue("{dataset_path}/description_merged.feather"))

predictions <- arrow::read_feather(
  glue("{data_results_path}/predictions_{run}_metabolomics.feather")
)

# Log-hazard columns "logh_<endpoint>_<features>" are reshaped to long format.
# The first capture group is greedy, so `features` is the text after the LAST
# underscore and `endpoint` keeps its internal underscores.
loghazards <- glue("{data_results_path}/loghazards_model_{run}_metabolomics.feather") %>%
  arrow::read_feather() %>%
  pivot_longer(
    starts_with("logh"),
    names_to = c("endpoint", "features"),
    values_to = "logh",
    names_pattern = "logh_?(.*)_(.*)$"
  )

In [None]:
# Text sizes shared by all figures of this notebook.
base_size <- 8
title_size <- 10
facet_size <- 10
geom_text_size <- 3
library(ggplot2)

# Global plotting theme: classic base, unboxed facet strips, thin axis
# lines/ticks and legends at the bottom.
theme_set(
  theme_classic(base_size = base_size) +
    theme(
      strip.background = element_blank(),
      plot.title = element_text(size = title_size, hjust = 0),
      strip.text.x = element_text(size = facet_size),
      axis.title = element_text(size = 10),
      axis.text = element_text(size = 8, color = "black"),
      legend.position = "bottom",
      axis.line = element_line(size = 0.2),
      axis.ticks = element_line(size = 0.2)
    )
)

In [None]:
# Test-split log-hazards joined by eid to the NMR measurements (QC-flag and
# spectrometer columns dropped, only rows with NMR_FLAG == TRUE kept).
logh_NMR <- loghazards %>%
  filter(split == "test") %>%
  left_join(
    data %>%
      select(eid, starts_with("NMR_"),
             -c(`NMR_measurement_quality_flagged`, `NMR_spectrometer`)) %>%
      filter(NMR_FLAG == TRUE),
    by = "eid"
  )

# Same table in long format: one row per (log-hazard row, NMR marker).
logh_NMR_long <- logh_NMR %>%
  pivot_longer(starts_with("NMR_"), names_to = "marker", values_to = "value")

# Disabled: per-endpoint Pearson correlation of log-hazard with each marker.
#corrs = logh_NMR_long %>% filter(marker!="NMR_FLAG") %>% group_by(endpoint, marker) %>% summarise(cor = cor(logh, value, use="complete.obs", method="pearson"))

In [None]:
library(ggforestplot)

In [None]:
# DeepExplainer attributions, tagged with the explainer that produced them.
attributions <- glue("{data_results_path}/attributions_211026.feather") %>%
  arrow::read_feather() %>%
  mutate(explainer = "DeepExplainer")

## Attributions by SHAP

In [None]:
# C-index benchmark results of this run, with `score` = "<module>_<features>".
run <- "211007"
name <- glue("benchmark_cindex_{run}")
# unite() only adds a column derived from existing ones, so rows that are
# distinct before it stay distinct afterwards; a single distinct() suffices.
benchmark_cindex_general <- read_feather(glue("{data_results_path}/{name}.feather")) %>%
  distinct() %>%
  unite("score", c(module, features), remove = FALSE)

In [None]:
library(ggdist)
# Endpoints ranked by median C-index of the DS model on metabolomics features,
# best first; used as the x-axis order of the attribution heatmaps.
perf_order <- benchmark_cindex_general %>%
  filter(module == "DS", features == "Metabolomics") %>%
  group_by(endpoint) %>%
  median_qi(cindex) %>%
  arrange(desc(cindex))
endpoint_order_perf <- perf_order %>% pull(endpoint)

In [None]:
library(ggthemes)
# Display label for each endpoint column name; used as the x-axis labels of the
# attribution heatmaps (scale_x_discrete(labels = endpoint_map)).
# NOTE(review): "obstructuve" in the COPD key looks like a typo but presumably
# matches the upstream column name — only rename it together with the data.
endpoint_map = c(
    'M_MACE'='MACE',
    'M_all_cause_dementia'='Dementia',
    'M_type_2_diabetes'='T2 Diabetes',
    'M_liver_disease'='Liver Disease',
    'M_renal_disease'='Renal Disease',
    'M_atrial_fibrillation'='Atrial Fibrillation',
    'M_heart_failure'= 'Heart Failure',
    'M_coronary_heart_disease'='CHD',
    'M_venous_thrombosis'='Ven. Thrombosis',
    'M_cerebral_stroke'='Cerebral Stroke',
    'M_abdominal_aortic_aneurysm'='AAA',
    'M_peripheral_arterial_disease'='PAD',
    "M_chronic_obstructuve_pulmonary_disease" = "COPD",
    "M_asthma" = "Asthma",
    'M_parkinsons_disease' = "Parkinson's",    
    "M_lung_cancer" = "Lung Cancer",
    "M_non_melanoma_skin_cancer" = "Skin Cancer",
    "M_colon_cancer"= "Colon Cancer",
    "M_rectal_cancer" = "Rectal Cancer",
    "M_prostate_cancer"= "Prostate Cancer",
    "M_breast_cancer" = "Breast Cancer",
    'M_cataracts' = "Cataracts", 
    'M_glaucoma' = "Glaucoma",
    'M_fractures' = "Fractures"
)

# Manual display order of all 24 endpoints: cardiovascular endpoints first,
# cancers last (an alternative to the performance-based endpoint_order_perf).
endpoint_order = c("M_MACE", "M_coronary_heart_disease", "M_cerebral_stroke", "M_all_cause_dementia", "M_heart_failure", "M_atrial_fibrillation",
                   "M_type_2_diabetes", "M_liver_disease", "M_renal_disease", "M_peripheral_arterial_disease", "M_venous_thrombosis",  "M_abdominal_aortic_aneurysm",
                   "M_chronic_obstructuve_pulmonary_disease", "M_asthma", 'M_parkinsons_disease', 'M_cataracts', 'M_glaucoma', 'M_fractures',
                    "M_lung_cancer","M_non_melanoma_skin_cancer","M_colon_cancer","M_rectal_cancer","M_prostate_cancer","M_breast_cancer"
                   
)

In [None]:
library(ggforestplot)
# Nightingale biomarker metadata plus a snake_case `metabolite` key built from
# the description; this key is joined against attributions$metabolite below.
ng_names <- df_NG_biomarker_metadata %>%
  mutate(metabolite = description %>% tolower() %>% str_replace_all(" ", "_"))
ng_names %>% sample_n(10)

In [None]:
# All (group, subgroup) combinations in the Nightingale metadata, sorted.
ng_names %>%
  distinct(group, subgroup) %>%
  arrange(group, subgroup)

In [None]:
library(fuzzyjoin)

In [None]:
library(fuzzyjoin)
# Fuzzy-match metabolite names against a reference table, keeping only the
# closest candidate per metabolite.
#
# stringdist_left_join() returns EVERY reference row within `max_dist`, so a
# permissive threshold can match one metabolite to several biomarkers. Joining
# such a table onto the attributions would then silently duplicate rows.
#
# @param unmatched Data frame with a `metabolite` column to match.
# @param reference Reference table keyed by `metabolite` (here `ng_names`).
# @param max_dist Maximum string distance passed to stringdist_left_join().
# @return One row per distinct input metabolite with the reference columns of
#   the nearest match (all NA when no candidate lies within `max_dist`). Ties
#   at equal distance resolve to the first candidate in reference order.
match_nearest_metabolite <- function(unmatched, reference, max_dist) {
  unmatched %>%
    select(metabolite) %>%
    stringdist_left_join(reference, by = "metabolite", max_dist = max_dist,
                         distance_col = "match_dist") %>%
    rename(metabolite = metabolite.x) %>%
    select(-metabolite.y) %>%
    # NA distances (no candidate) sort last, so unmatched names are still kept.
    arrange(match_dist) %>%
    distinct(metabolite, .keep_all = TRUE) %>%
    select(-match_dist)
}

# Map each attribution metabolite onto the Nightingale metadata in three passes:
# exact key match, fuzzy within distance 1, then fuzzy within distance 8.
mets1 <- attributions %>%
  distinct(metabolite) %>%
  left_join(ng_names, by = "metabolite")
mets2 <- mets1 %>%
  filter(is.na(name)) %>%
  match_nearest_metabolite(ng_names, max_dist = 1)
mets3 <- mets2 %>%
  filter(is.na(name)) %>%
  match_nearest_metabolite(ng_names, max_dist = 8)
# Names still unmatched after the last pass stay in `mets` with NA metadata.
mets <- bind_rows(
  mets1 %>% filter(!is.na(name)),
  mets2 %>% filter(!is.na(name)),
  mets3
)
# Spot-check: a distance of 8 is permissive, so inspect the fuzzy matches.
mets %>% sample_n(5)

In [None]:
# Attach Nightingale abbreviation/group/subgroup to every attribution row and
# normalise eid to integer (via character, so factor eids keep their labels).
attributions_metadata <- attributions %>%
  left_join(
    select(mets, metabolite, abbreviation, group, subgroup),
    by = "metabolite"
  ) %>%
  mutate(eid = eid %>% as.character() %>% as.integer())

In [None]:
library(gghighlight)

In [None]:
# Raw NMR measurements in long format, one row per (eid, metabolite). The
# "NMR_" prefix is stripped so names line up with the attribution key.
# NOTE(review): NMR_FLAG also matches "NMR_" and ends up as metabolite "FLAG".
nmr_real <- data %>%
  select(eid, starts_with("NMR_"),
         -`NMR_measurement_quality_flagged`, -`NMR_spectrometer`) %>%
  filter(NMR_FLAG == TRUE) %>%
  pivot_longer(contains("NMR_"), names_to = "metabolite", values_to = "met_real") %>%
  mutate(metabolite = str_remove_all(metabolite, "NMR_"))

In [None]:
# Endpoint status per (eid, endpoint): the plain M_<endpoint> columns, with the
# "_event" and "_time" outcome columns excluded.
# NOTE(review): presumably these plain columns flag prevalent disease — verify.
prev_events <- data %>%
  select(eid, starts_with("M_"), -ends_with("_event"), -ends_with("_time")) %>%
  pivot_longer(-eid, names_to = "endpoint", values_to = "event") %>%
  distinct()
prev_events %>% head()

In [None]:
#' Turn snake_case labels into wrapped display text.
#'
#' @param label Character vector of labels.
#' @param width Target line width for stringr::str_wrap(); defaults to 20, the
#'   previously hard-coded value, so existing calls are unchanged.
#' @return Character vector with underscores replaced by spaces, wrapped to
#'   `width` characters per line.
clean_label <- function(label, width = 20) {
  stringr::str_wrap(stringr::str_replace_all(label, "_", " "), width = width)
}

In [None]:
# Test-split hazard ratios (exp of the metabolomics-model log-hazard).
hrs <- loghazards %>%
  filter(features == "Metabolomics", split == "test") %>%
  mutate(hr = exp(logh)) %>%
  select(eid, endpoint, hr)

## Global attributions

In [None]:
# (To subsample, draw a random set of eids from attributions_metadata and
#  filter on them before the joins below.)

# Metabolite display order: Nightingale group, then subgroup, then abbreviation.
met_order_df <- attributions_metadata %>%
  distinct(group, subgroup, metabolite, abbreviation) %>%
  arrange(group, subgroup, abbreviation) %>%
  mutate(group_id = as.integer(factor(group)))
met_order <- met_order_df$metabolite
abbrev_order <- met_order_df$abbreviation
group_order <- unique(met_order_df$group)

# Attributions enriched with raw NMR value, hazard ratio and endpoint status;
# metabolite and abbreviation become factors in the order above.
attrib_raw <- attributions_metadata %>%
  left_join(nmr_real, by = c("eid", "metabolite")) %>%
  left_join(hrs, by = c("eid", "endpoint")) %>%
  left_join(prev_events, by = c("eid", "endpoint")) %>%
  ungroup() %>%
  mutate(
    metabolite = factor(metabolite, levels = met_order),
    abbreviation = factor(abbreviation, levels = abbrev_order)
  )

In [None]:
# Display order of Nightingale metabolite subgroups. It sets the facet-row order
# of the heatmaps via factor(subgroup, levels = subgroup_order); subgroups not
# listed here become NA in those factors.
subgroup_order = c( 'Amino acids',
                    'Branched-chain amino acids',
                   'Aromatic amino acids',
                   'Fluid balance',
                   'Inflammation',
                    'Fatty acids',
                    'Glycolysis related metabolites',
                    'Ketone bodies',
         
                   'Total lipids',
                    'Cholesterol',
                    'Free cholesterol',
                   'Cholesteryl esters',
                   'Phospholipids',
                   'Triglycerides',
                   'Other lipids',
                   
                    'Lipoprotein particle sizes',
                    'Lipoprotein particle concentrations',
                    'Chylomicrons and extremely large VLDL',
                   'Very large VLDL',
                   'Large VLDL',
                   'Medium VLDL',
                   'Small VLDL',
                   'Very small VLDL',
              
                   'Large LDL',
                   'Medium LDL',
                   'Small LDL',
                    'IDL',
                   'Very large HDL',
                   'Large HDL',
                   'Medium HDL',
                   'Small HDL',
                   'Apolipoproteins'
                  )

In [None]:
# Within each (endpoint, metabolite, explainer): percentile bins 1-100 of the
# SHAP value and of the metabolite value. The result stays grouped.
# NOTE(review): `met_value` is not created in this notebook (the joined raw
# value is `met_real`); presumably it comes from the attribution export — verify.
attrib_sample <- attrib_raw %>%
  group_by(endpoint, metabolite, explainer) %>%
  mutate(
    shap_quantile = ntile(shap, 100),
    met_quantile = ntile(met_value, 100)
  )

In [None]:
# Mean metabolite percentile, SHAP value and metabolite value per SHAP
# percentile bin. group_by() replaces the existing grouping of attrib_sample.
attrib_sample_mean <- attrib_sample %>%
  group_by(endpoint, metabolite, abbreviation, group, subgroup, explainer,
           shap_quantile) %>%
  summarise(
    met_quantile = mean(met_quantile),
    mean_shap = mean(shap),
    mean_met = mean(met_value)
  )

In [None]:
library(ggforce)

In [None]:
# Endpoints shown in the focused plots. Candidates currently left out:
# M_coronary_heart_disease, M_cerebral_stroke,
# M_chronic_obstructuve_pulmonary_disease, M_parkinsons_disease.
endpoint_selection <- c(
  "M_MACE",
  "M_all_cause_dementia",
  "M_type_2_diabetes",
  "M_renal_disease",
  "M_venous_thrombosis",
  "M_asthma"
)

In [None]:
# Regroup lipoprotein measures by measurement type, read from the abbreviation
# suffix (e.g. "-P" particle concentration, "-CE" cholesteryl esters). All
# other metabolites keep their Nightingale subgroup. The suffixes are mutually
# exclusive (e.g. "-FC" never ends in "-C"), so branch order does not matter.
attrib_sample_mean <- attrib_sample_mean %>%
  mutate(
    group_new = case_when(
      str_ends(abbreviation, "-P")  ~ "Lipoprotein particle concentrations",
      str_ends(abbreviation, "-L")  ~ "Total lipids",
      str_ends(abbreviation, "-C")  ~ "Cholesterol",
      str_ends(abbreviation, "-FC") ~ "Free cholesterol",
      str_ends(abbreviation, "-CE") ~ "Cholesteryl esters",
      str_ends(abbreviation, "-PL") ~ "Phospholipids",
      str_ends(abbreviation, "-TG") ~ "Triglycerides",
      TRUE ~ subgroup
    )
  )

In [None]:
# Global importance: total absolute SHAP per endpoint x metabolite, summed over
# all remaining rows (individuals) in each group.
temp_global <- attrib_sample %>%
  group_by(endpoint, subgroup, metabolite, abbreviation) %>%
  summarise(global_shap = shap %>% abs() %>% sum())

In [None]:
# The 75 metabolites with the largest mean global attribution across endpoints.
# global_shap is a sum of absolute values, so no extra abs() is needed.
met_selection <- temp_global %>%
  group_by(metabolite) %>%
  summarise(mean_global = mean(global_shap, na.rm = TRUE)) %>%
  arrange(desc(mean_global)) %>%
  head(75) %>%
  pull(metabolite)

In [None]:
# Figure 6A: heatmap of global attribution (summed |SHAP|) for the metabolites
# in `met_selection`, faceted by metabolite subgroup. Endpoints are ordered by
# model performance (`endpoint_order_perf`).
plot_width <- 3.25
plot_height <- 10
plot_dpi <- 320
# Preview size and resolution use the same variables as ggsave() below, so the
# notebook preview and the saved figure cannot drift apart.
options(repr.plot.width = plot_width, repr.plot.height = plot_height,
        repr.plot.res = plot_dpi)
attr_delta <- ggplot(
  temp_global %>%
    filter(metabolite %in% met_selection) %>%
    mutate(subgroup = factor(subgroup, levels = subgroup_order)),
  aes(x = factor(endpoint, levels = endpoint_order_perf),
      y = fct_rev(abbreviation),
      fill = abs(global_shap))
) +
  labs(x = NULL, y = NULL) +
  geom_tile() +
  theme(plot.title = element_text(vjust = -15)) +
  scale_fill_gradient2(low = "darkblue", high = "#440154FF", midpoint = 0) +
  theme(legend.position = "bottom") +
  scale_x_discrete(labels = endpoint_map, position = "top") +
  scale_y_discrete(position = "left") +
  facet_grid(subgroup ~ ., labeller = labeller(subgroup = label_wrap_gen(20)),
             scales = "free", space = "free") +
  theme(axis.text.x = element_text(size = 6),
        axis.text.y = element_text(size = 5.5),
        strip.text.y.right = element_text(angle = 0, size = 6)) +
  theme(axis.text.x.top = element_text(hjust = 0, vjust = 0.5)) +
  theme(strip.placement = "outside") +
  theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1)) +
  theme(panel.spacing = unit(0.5, "lines"))
attr_delta

In [None]:
library(gt)
plot_name <- "Figures_6_A_AttributionHeatmap75"
# ggsave() does not create missing directories by default, so make sure the
# relative output folder exists before writing the PDF.
dir.create("outputs", showWarnings = FALSE, recursive = TRUE)
ggsave(
  filename = glue("outputs/{plot_name}.pdf"),
  plot = attr_delta,
  device = "pdf",
  width = plot_width,
  height = plot_height,
  dpi = plot_dpi
)

In [None]:
# Suppl. Figure 7, left half: the first 84 metabolites in subgroup order.
# The remaining metabolites form the right half.
met_selection_top <- temp_global %>%
  ungroup() %>%
  select(metabolite, subgroup) %>%
  distinct() %>%
  mutate(subgroup = factor(subgroup, levels = subgroup_order)) %>%
  arrange(subgroup) %>%
  head(84) %>%
  pull(metabolite)

# Both halves are shown side by side without a legend, so they must share one
# colour scale. The fill limits are pinned to the range over ALL metabolites
# instead of letting each half rescale to its own maximum.
heatmap_fill_limits <- c(0, max(abs(temp_global$global_shap), na.rm = TRUE))

plot_width <- 4
plot_height <- 10
plot_dpi <- 320
options(repr.plot.width = plot_width, repr.plot.height = plot_height,
        repr.plot.res = 320)
attr_delta_full_left <- ggplot(
  temp_global %>%
    filter(metabolite %in% met_selection_top) %>%
    mutate(subgroup = factor(subgroup, levels = subgroup_order)),
  aes(x = factor(endpoint, levels = endpoint_order_perf),
      y = fct_rev(abbreviation),
      fill = abs(global_shap))
) +
  labs(x = NULL, y = NULL) +
  geom_tile() +
  theme(plot.title = element_text(vjust = -15)) +
  scale_fill_gradient2(low = "darkblue", high = "#440154FF", midpoint = 0,
                       limits = heatmap_fill_limits) +
  theme(legend.position = "none") +
  scale_x_discrete(labels = endpoint_map, position = "top") +
  scale_y_discrete(position = "left") +
  facet_grid(subgroup ~ ., labeller = labeller(subgroup = label_wrap_gen(25)),
             scales = "free", space = "free") +
  theme(axis.text.x = element_text(size = 6),
        axis.text.y = element_text(size = 6),
        strip.text.y.right = element_text(angle = 0, size = 6)) +
  theme(axis.text.x.top = element_text(hjust = 0, vjust = 0.5)) +
  theme(strip.placement = "outside") +
  theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1))
attr_delta_full_left

In [None]:
# Suppl. Figure 7, right half: every metabolite NOT in `met_selection_top`.
# The fill limits span ALL metabolites, so colours match the left half; both
# halves are combined without a legend.
heatmap_fill_limits <- c(0, max(abs(temp_global$global_shap), na.rm = TRUE))

plot_width <- 4
plot_height <- 10
plot_dpi <- 320
options(repr.plot.width = plot_width, repr.plot.height = plot_height,
        repr.plot.res = 320)
attr_delta_full_right <- ggplot(
  temp_global %>%
    filter(!(metabolite %in% met_selection_top)) %>%
    mutate(subgroup = factor(subgroup, levels = subgroup_order)),
  aes(x = factor(endpoint, levels = endpoint_order_perf),
      y = fct_rev(abbreviation),
      fill = abs(global_shap))
) +
  labs(x = NULL, y = NULL) +
  geom_tile() +
  theme(plot.title = element_text(vjust = -15)) +
  scale_fill_gradient2(low = "darkblue", high = "#440154FF", midpoint = 0,
                       limits = heatmap_fill_limits) +
  theme(legend.position = "none") +
  scale_x_discrete(labels = endpoint_map, position = "top") +
  scale_y_discrete(position = "left") +
  facet_grid(subgroup ~ ., labeller = labeller(subgroup = label_wrap_gen(25)),
             scales = "free", space = "free") +
  theme(axis.text.x = element_text(size = 6),
        axis.text.y = element_text(size = 6),
        strip.text.y.right = element_text(angle = 0, size = 6)) +
  theme(axis.text.x.top = element_text(hjust = 0, vjust = 0.5)) +
  theme(strip.placement = "outside") +
  theme(axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1))
attr_delta_full_right

In [None]:
library(patchwork)

In [None]:
# Suppl. Figure 7: left and right halves side by side (patchwork `|`).
plot_width <- 8
plot_height <- 10
plot_dpi <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = 320
)
attr_delta_full_final <- attr_delta_full_left | attr_delta_full_right

In [None]:
library(gt)
plot_name <- "Suppl_Figures_7_AttributionHeatmapFull"
# ggsave() does not create missing directories by default, so make sure the
# relative output folder exists before writing the PDF.
dir.create("outputs", showWarnings = FALSE, recursive = TRUE)
ggsave(
  filename = glue("outputs/{plot_name}.pdf"),
  plot = attr_delta_full_final,
  device = "pdf",
  width = plot_width,
  height = plot_height,
  dpi = plot_dpi
)