## Initialize

In [None]:
#library(Rmisc)
library(tidyverse)
library(glue)
library(arrow)
library(patchwork)
library("ggbeeswarm")

In [None]:
# Resolve the data root for the current host: machines whose node name
# contains "sc" use the sc-projects mount, all others the ag-reils share.
base_path <- if (grepl("sc", Sys.info()[["nodename"]], fixed = TRUE)) {
  "/sc-projects/sc-proj-ukb-cvd"
} else {
  "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
}
print(base_path)

# Input dataset locations (`path` is not used in the cells shown here).
dataset_name <- "210714_metabolomics"
path <- "/data/analysis/ag-reils/steinfej/code/umbrella/pre/ukbb"
data_path <- glue("{base_path}/data")
dataset_path <- glue("{data_path}/3_datasets_post/{dataset_name}")

# Project output locations.
project_label <- "21_metabolomics_multitask"
project_path <- glue("{base_path}/results/projects/{project_label}")
figures_path <- glue("{project_path}/figures")
data_results_path <- glue("{project_path}/data")

## Load data

In [None]:
list.dirs(project_path, full.names = TRUE, recursive = TRUE)

In [None]:
# Model run whose predictions / log-hazards are analysed.
run <- "211007"

# Merged UK Biobank dataset and its column description.
data <- arrow::read_feather(glue("{dataset_path}/data_merged.feather"))
data_description <- arrow::read_feather(glue("{dataset_path}/description_merged.feather"))

# Model outputs. Log-hazard columns named "logh_<endpoint>_<features>" are
# reshaped to long format with one row per endpoint x feature set.
predictions <- arrow::read_feather(glue("{data_results_path}/predictions_{run}_metabolomics.feather"))
loghazards <- glue("{data_results_path}/loghazards_model_{run}_metabolomics.feather") %>%
  arrow::read_feather() %>%
  pivot_longer(
    starts_with("logh"),
    names_to = c("endpoint", "features"),
    values_to = "logh",
    names_pattern = "logh_?(.*)_(.*)$"
  )

In [None]:
# Shared font sizes for all figures.
base_size <- 8
title_size <- 10
facet_size <- 10
geom_text_size <- 3
library(ggplot2)

# Global ggplot theme: classic look, bare facet strips, legend at the bottom
# and thin axis lines/ticks.
theme_set(
  theme_classic(base_size = base_size) +
    theme(
      strip.background = element_blank(),
      plot.title = element_text(size = title_size, hjust = 0),
      strip.text.x = element_text(size = facet_size),
      axis.title = element_text(size = 10),
      axis.text = element_text(size = 8, color = "black"),
      legend.position = "bottom",
      axis.line = element_line(size = 0.2),
      axis.ticks = element_line(size = 0.2)
    )
)

In [None]:
# Test-split log-hazards joined to each participant's NMR measurements
# (participants with NMR_FLAG == TRUE only; QC/spectrometer columns dropped).
logh_NMR <- loghazards %>%
  filter(split == "test") %>%
  left_join(
    data %>%
      select(eid, starts_with("NMR_"), -NMR_measurement_quality_flagged, -NMR_spectrometer) %>%
      filter(NMR_FLAG == TRUE),
    by = "eid"
  )

# Same table in long format: one row per NMR column ("marker") and value.
logh_NMR_long <- logh_NMR %>%
  pivot_longer(starts_with("NMR_"), names_to = "marker", values_to = "value")

In [None]:
library(ggforestplot)

In [None]:
# Per-participant attributions computed with DeepExplainer, tagged with the
# explainer name so several explainers can be stacked later.
attributions <- glue("{data_results_path}/attributions_211026.feather") %>%
  arrow::read_feather() %>%
  mutate(explainer = "DeepExplainer")

## Attributions by SHAP

In [None]:
library(ggthemes)
# Display label for every endpoint column. Keys must match the column names
# in the data exactly (including the existing "obstructuve" spelling).
endpoint_map <- c(
  "M_MACE" = "MACE",
  "M_all_cause_dementia" = "Dementia",
  "M_type_2_diabetes" = "T2 Diabetes",
  "M_liver_disease" = "Liver Disease",
  "M_renal_disease" = "Renal Disease",
  "M_atrial_fibrillation" = "Atrial Fibrillation",
  "M_heart_failure" = "Heart Failure",
  "M_coronary_heart_disease" = "CHD",
  "M_venous_thrombosis" = "Ven. Thrombosis",
  "M_cerebral_stroke" = "Cerebral Stroke",
  "M_abdominal_aortic_aneurysm" = "AAA",
  "M_peripheral_arterial_disease" = "PAD",
  "M_chronic_obstructuve_pulmonary_disease" = "COPD",
  "M_asthma" = "Asthma",
  "M_parkinsons_disease" = "Parkinson's",
  "M_lung_cancer" = "Lung Cancer",
  "M_non_melanoma_skin_cancer" = "Skin Cancer",
  "M_colon_cancer" = "Colon Cancer",
  "M_rectal_cancer" = "Rectal Cancer",
  "M_prostate_cancer" = "Prostate Cancer",
  "M_breast_cancer" = "Breast Cancer",
  "M_cataracts" = "Cataracts",
  "M_glaucoma" = "Glaucoma",
  "M_fractures" = "Fractures"
)

# Order in which endpoints are presented: cardiometabolic first, then other
# chronic conditions, then cancers.
endpoint_order <- c(
  # cardiometabolic
  "M_MACE", "M_coronary_heart_disease", "M_cerebral_stroke",
  "M_all_cause_dementia", "M_heart_failure", "M_atrial_fibrillation",
  "M_type_2_diabetes", "M_liver_disease", "M_renal_disease",
  "M_peripheral_arterial_disease", "M_venous_thrombosis", "M_abdominal_aortic_aneurysm",
  # other chronic conditions
  "M_chronic_obstructuve_pulmonary_disease", "M_asthma", "M_parkinsons_disease",
  "M_cataracts", "M_glaucoma", "M_fractures",
  # cancers
  "M_lung_cancer", "M_non_melanoma_skin_cancer", "M_colon_cancer",
  "M_rectal_cancer", "M_prostate_cancer", "M_breast_cancer"
)

In [None]:
library(ggforestplot)
# Nightingale biomarker metadata with a snake_case `metabolite` key derived
# from the lower-cased description, used to match attribution names.
ng_names <- df_NG_biomarker_metadata %>%
  mutate(metabolite = description %>% tolower() %>% str_replace_all(" ", "_"))
ng_names %>% sample_n(10)

In [None]:
library(fuzzyjoin)

In [None]:
library(fuzzyjoin)
# Match every attributed metabolite to the Nightingale metadata in `ng_names`
# in three passes of increasing looseness:
#   1. exact match on the snake_case key,
#   2. fuzzy match within string distance 1 for names still unmatched,
#   3. fuzzy match within string distance 8 for names still unmatched
#      (anything left over keeps NA metadata).
# A fuzzy join can return several candidates for one metabolite. Keeping all
# of them would duplicate attribution rows in the later left_join and could
# give one metabolite several subgroups/abbreviations. Each fuzzy pass
# therefore keeps only the closest candidate; ties go to the first candidate
# returned, because arrange() is stable.
mets1 <- attributions %>%
  distinct(metabolite) %>%
  left_join(ng_names, by = "metabolite")

mets2 <- mets1 %>%
  filter(is.na(name)) %>%
  select(metabolite) %>%
  stringdist_left_join(ng_names, by = "metabolite", max_dist = 1, distance_col = "match_dist") %>%
  rename(metabolite = metabolite.x) %>%
  select(-metabolite.y) %>%
  distinct() %>%
  arrange(match_dist) %>% # NA (unmatched) sorts last, so real matches win
  distinct(metabolite, .keep_all = TRUE) %>%
  select(-match_dist)

mets3 <- mets2 %>%
  filter(is.na(name)) %>%
  select(metabolite) %>%
  stringdist_left_join(ng_names, by = "metabolite", max_dist = 8, distance_col = "match_dist") %>%
  rename(metabolite = metabolite.x) %>%
  select(-metabolite.y) %>%
  distinct() %>%
  arrange(match_dist) %>%
  distinct(metabolite, .keep_all = TRUE) %>%
  select(-match_dist)

mets <- bind_rows(
  mets1 %>% filter(!is.na(name)),
  mets2 %>% filter(!is.na(name)),
  mets3
)
mets %>% sample_n(5)

In [None]:
# Attach abbreviation / group / subgroup metadata to every attribution row and
# normalise `eid` to integer so it joins with the other per-participant tables.
attributions_metadata <- attributions %>%
  left_join(select(mets, metabolite, abbreviation, group, subgroup), by = "metabolite") %>%
  mutate(eid = as.integer(as.character(eid)))

In [None]:
library(gghighlight)

In [None]:
# Measured NMR values in long format (one row per eid x metabolite), limited
# to participants with NMR_FLAG == TRUE. The "NMR_" prefix is stripped so the
# names match `metabolite` in the attributions.
# NMR_FLAG is a selection flag, not a measurement. It is excluded from the
# pivot (and then dropped) so it is not coerced into `met_real` as a
# pseudo-metabolite.
nmr_real <- data %>%
  select(eid, starts_with("NMR_"), -NMR_measurement_quality_flagged, -NMR_spectrometer) %>%
  filter(NMR_FLAG == TRUE) %>%
  pivot_longer(c(starts_with("NMR_"), -NMR_FLAG), names_to = "metabolite", values_to = "met_real") %>%
  select(-NMR_FLAG) %>%
  mutate(metabolite = str_remove(metabolite, "^NMR_"))

In [None]:
# Long table of the "M_*" endpoint columns that are neither "_event" nor
# "_time" columns: one row per participant x endpoint, with the column's value
# in `event`. Presumably these are prevalent-disease indicators; verify
# against the data description.
prev_events <- data %>%
  select(eid, starts_with("M_"), -ends_with("_event"), -ends_with("_time")) %>%
  pivot_longer(contains("M_"), names_to = "endpoint", values_to = "event") %>%
  distinct()
head(prev_events)

In [None]:
# Turn a snake_case identifier into a display label: underscores become spaces
# and the text is wrapped to a width of 20 characters.
clean_label <- function(label) {
  spaced <- stringr::str_replace_all(label, "_", " ")
  stringr::str_wrap(spaced, 20)
}

In [None]:
# Test-split hazard ratios, exp(log-hazard), of the metabolomics-only model,
# one row per participant x endpoint.
hrs <- loghazards %>%
  filter(features == "Metabolomics", split == "test") %>%
  mutate(hr = exp(logh)) %>%
  select(eid, endpoint, hr)

## Global attributions

In [None]:
# Display order of metabolites: by group, then subgroup, then abbreviation.
met_order_df <- attributions_metadata %>%
  select(group, subgroup, metabolite, abbreviation) %>%
  distinct() %>%
  arrange(group, subgroup, abbreviation) %>%
  mutate(group_id = as.integer(factor(group)))

# factor(levels = ...) errors on duplicated levels. Duplicates appear when two
# metabolites were matched to the same abbreviation upstream, or one
# metabolite to two metadata rows, so keep first occurrences only.
met_order <- unique(met_order_df$metabolite)
abbrev_order <- unique(met_order_df$abbreviation)
group_order <- unique(met_order_df$group)

# Attribution rows enriched with the measured metabolite value, the
# participant's hazard ratio and the prevalent-event value for the endpoint.
attrib_raw <- attributions_metadata %>%
  left_join(nmr_real, by = c("eid", "metabolite")) %>%
  left_join(hrs, by = c("eid", "endpoint")) %>%
  left_join(prev_events, by = c("eid", "endpoint")) %>%
  ungroup() %>%
  mutate(
    metabolite = factor(metabolite, levels = met_order),
    abbreviation = factor(abbreviation, levels = abbrev_order)
  )

In [None]:
# Order of metabolite subgroups around the circular plots (one sector each).
subgroup_order <- c(
  # small molecules and non-lipid measures
  "Amino acids", "Branched-chain amino acids", "Aromatic amino acids",
  "Fluid balance", "Inflammation", "Fatty acids",
  "Glycolysis related metabolites", "Ketone bodies",
  # lipid classes
  "Total lipids", "Cholesterol", "Free cholesterol", "Cholesteryl esters",
  "Phospholipids", "Triglycerides", "Other lipids",
  # lipoprotein summaries and VLDL subclasses
  "Lipoprotein particle sizes", "Lipoprotein particle concentrations",
  "Chylomicrons and extremely large VLDL", "Very large VLDL", "Large VLDL",
  "Medium VLDL", "Small VLDL", "Very small VLDL",
  # LDL, IDL and HDL subclasses, apolipoproteins
  "Large LDL", "Medium LDL", "Small LDL", "IDL",
  "Very large HDL", "Large HDL", "Medium HDL", "Small HDL",
  "Apolipoproteins"
)

In [None]:
# Percentile buckets (1-100) of the SHAP value and of the metabolite value,
# computed within each endpoint x metabolite x explainer. The result stays
# grouped by those three columns.
attrib_sample <- attrib_raw %>%
  group_by(endpoint, metabolite, explainer) %>%
  mutate(
    shap_quantile = ntile(shap, n = 100),
    met_quantile = ntile(met_value, n = 100)
  )

In [None]:
# Mean SHAP, mean metabolite value and mean metabolite percentile per SHAP
# percentile bucket, for every endpoint x metabolite x explainer. The
# metadata columns are carried along as grouping keys.
# `.groups = "drop"` returns an ungrouped table. Otherwise summarise() leaves
# it grouped by all keys but shap_quantile, and later per-table computations
# (e.g. rescaling) would silently run per group.
attrib_sample_mean <- attrib_sample %>%
  group_by(endpoint, metabolite, abbreviation, group, subgroup, explainer, shap_quantile) %>%
  summarise(
    met_quantile = mean(met_quantile),
    mean_shap = mean(shap),
    mean_met = mean(met_value),
    .groups = "drop"
  )

## Specific Endpoints

In [None]:
suppressPackageStartupMessages(library(circlize))

In [None]:
library(purrr)

In [None]:
library(plotly)

In [None]:
library(scales)

In [None]:
# Default endpoint for interactive exploration.
endpoint <- "M_all_cause_dementia"

In [None]:
# Return the colours in `col` as "#RRGGBBAA" hex strings with opacity `alpha`
# (0 = fully transparent, 1 = opaque). Any alpha already present in `col` is
# discarded and replaced.
#
# col:   colours accepted by col2rgb() (names, hex strings, palette indices).
# alpha: opacity in [0, 1]; a single value or one value per colour (recycled
#        by rgb()).
# Returns a character vector named by names(col) if present, otherwise by the
# colour strings themselves (same naming as the previous implementation).
# An empty `col` gives character(0) instead of an error.
add.alpha <- function(col, alpha = 1) {
  if (missing(col)) {
    stop("Please provide a vector of colours.")
  }
  if (length(col) == 0L) {
    return(character(0))
  }
  # One vectorised conversion instead of a col2rgb() call per colour; rgb()
  # pairs colour i with alpha[i], so per-element opacity needs no loop.
  channels <- grDevices::col2rgb(col) / 255
  out <- grDevices::rgb(channels[1, ], channels[2, ], channels[3, ], alpha = alpha)
  out_names <- names(col)
  if (is.null(out_names) && is.character(col)) {
    out_names <- col
  }
  if (length(out) == length(col)) {
    names(out) <- out_names
  }
  out
}

In [None]:
library(scales)

In [None]:
# Circular-plot sector layout: one row per metabolite subgroup, ordered by
# `subgroup_order`, with `n` distinct metabolites and x-range [0, n + 1].
# NOTE(review): subgroups missing from `subgroup_order` become NA here; confirm
# none exist.
sectors <- attrib_sample_mean %>%
  ungroup() %>%
  distinct(subgroup, metabolite) %>%
  count(subgroup) %>%
  mutate(subgroup = factor(subgroup, levels = subgroup_order)) %>%
  arrange(subgroup) %>%
  mutate(x1 = 0, x2 = n + 1)

In [None]:
# Notebook figure size and resolution.
options(repr.plot.width = 12, repr.plot.height = 12, repr.plot.res = 320)
library(circlize)
library(scales)
endpoint <- "M_all_cause_dementia"

# Diverging palette for mean metabolite values: -1 blue, 0 white, +1 red.
col_fun <- colorRamp2(c(-1, 0, 1), c("blue", "white", "red"))

# Draw the circular global-attribution figure for one endpoint on the current
# graphics device.
#
# Three concentric circlize tracks are created, one sector per metabolite
# subgroup:
#   track 1 - subgroup names,
#   track 2 - metabolite abbreviations,
#   track 3 - for every SHAP percentile, one point per metabolite at its mean
#             SHAP value, coloured by mean metabolite value (col_fun) with
#             opacity from the rescaled |mean SHAP|, plus guide lines at
#             SHAP = 0.2, 0 and -0.2.
# The endpoint's display name is written in the centre.
#
# endpoint: endpoint column name, e.g. "M_type_2_diabetes"; must be a key of
#           endpoint_map.
# Uses the notebook globals attrib_sample_mean, sectors, subgroup_order,
# col_fun, add.alpha and endpoint_map. Called for its drawing side effect.
plot_endpoint <- function(endpoint) {
  # Radial offset of the text in the two label tracks; the dementia figure
  # uses a smaller offset than every other endpoint.
  add <- if (endpoint == "M_all_cause_dementia") 0.06 else 0.2

  circos.clear()
  circos.par(
    "track.height" = 0.2, cell.padding = c(0, 0, 0, 0), gap.degree = 0,
    track.margin = c(.001, .001), start.degree = 90, clock.wise = TRUE
  )

  # Opacity is |mean SHAP| (double log1p-compressed) rescaled to [0, 1]. The
  # rescaling happens before the endpoint filter, i.e. across all endpoints.
  temp_all <- attrib_sample_mean %>%
    ungroup() %>%
    mutate(alpha = rescale(log1p(log1p(abs(mean_shap))), to = c(0, 1))) %>%
    filter(endpoint == !!endpoint) %>%
    mutate(subgroup = factor(subgroup, levels = subgroup_order))

  # One sector per subgroup level; x-ranges come from the global `sectors`
  # table. Assumes its rows follow subgroup_order too (TODO confirm).
  circos.initialize(sectors = levels(temp_all$subgroup), xlim = sectors %>% select(x1, x2))

  # Track 1: subgroup names.
  circos.par("track.height" = 0.16)
  circos.track(temp_all$subgroup, y = temp_all$mean_shap)
  circos.trackPlotRegion(track.index = 1, panel.fun = function(x, y) {
    xlim <- get.cell.meta.data("xlim")
    sector_id <- get.cell.meta.data("sector.index")
    size <- as.integer(xlim[2] - xlim[1]) + 1
    # Narrow sectors keep their name on one line; wider ones wrap at 15 chars.
    sector.name <- if (size < 4) sector_id else str_wrap(sector_id, 15)

    circos.text(
      get.cell.meta.data("xcenter") + 0.5, get.cell.meta.data("ylim")[1] + add, sector.name,
      facing = "clockwise", niceFacing = TRUE, adj = c(0, 0.5), cex = 1,
      col = add.alpha("black", 0.7), font = 2
    )

    x <- seq(xlim[1] + 1, xlim[2])
    circos.lines(
      x = x, y = rep(get.cell.meta.data("ylim")[1], length(x)),
      col = add.alpha("gray50", 0.5), pch = 20, cex = 0.5
    )
  }, bg.border = "white", bg.col = "white")

  # Track 2: metabolite abbreviations.
  circos.par("track.height" = 0.2)
  circos.track(temp_all$subgroup, y = temp_all$mean_shap)
  circos.trackPlotRegion(track.index = 2, panel.fun = function(x, y) {
    temp_label <- temp_all %>%
      ungroup() %>%
      dplyr::select(subgroup, metabolite, abbreviation) %>%
      distinct() %>%
      filter(subgroup == get.cell.meta.data("sector.index"))

    x <- seq(get.cell.meta.data("xlim")[1] + 1, get.cell.meta.data("xlim")[2])
    circos.text(
      x = x, y = get.cell.meta.data("ylim")[1] + add, temp_label$abbreviation,
      facing = "clockwise", niceFacing = TRUE, adj = c(0, 0.5), cex = 0.75,
      col = add.alpha("black", 0.7)
    )
    circos.lines(
      x = x, y = rep(get.cell.meta.data("ylim")[1], length(x)),
      col = add.alpha("gray50", 0.5), pch = 20, cex = 0.5
    )
  }, bg.border = "white", bg.col = "white")

  # Track 3: mean SHAP per metabolite for every SHAP percentile bucket.
  circos.par("track.height" = 0.4)
  circos.track(temp_all$subgroup, y = temp_all$mean_shap, ylim = c(-1, +1.2))
  circos.trackPlotRegion(track.index = 3, panel.fun = function(x, y) {
    temp_met <- temp_all %>%
      ungroup() %>%
      filter(subgroup == get.cell.meta.data("sector.index")) %>%
      mutate(color = col_fun(mean_met)) %>%
      rowwise() %>%
      mutate(color_a = add.alpha(color, alpha = alpha))

    x <- seq(get.cell.meta.data("xlim")[1], get.cell.meta.data("xlim")[2])
    circos.lines(
      x = x, rep(get.cell.meta.data("ylim")[1], length(x)),
      col = add.alpha("gray50", 0.5), pch = 20, cex = 0.5
    )

    for (q in unique(temp_met$shap_quantile)) {
      first_per_met <- temp_met %>%
        filter(shap_quantile == !!q) %>%
        group_by(metabolite) %>%
        slice_head(n = 1)
      # NA padding leaves the first and last x positions empty.
      # NOTE(review): y has (#metabolites with this bucket) + 2 values, while x
      # has xlim[2] - xlim[1] + 1. They only line up when every metabolite of
      # the sector has this bucket; TODO confirm.
      y <- c(NA, first_per_met$mean_shap, NA)
      color <- c(NA, first_per_met$color_a, NA)
      circos.points(jitter(x, 1), y = y, cex = 0.8, pch = 16, col = color)
    }

    # Horizontal guides at SHAP = +0.2, 0 and -0.2.
    circos.lines(x, rep(0.2, length(x)), col = add.alpha("firebrick", 0.5), pch = 20, cex = 0.5)
    circos.lines(x, rep(0, length(x)), col = add.alpha("grey50", 0.5), pch = 20, cex = 0.5)
    circos.lines(x, rep(-0.2, length(x)), col = add.alpha("forestgreen", 0.5), pch = 20, cex = 0.5)
  }, bg.border = "white", bg.col = "white")

  text(0, 0, endpoint_map[[endpoint]], cex = 1.5, col = add.alpha("black", 0.7))
}

In [None]:
# Figure 6B: global attributions for type 2 diabetes.
dir.create("outputs", showWarnings = FALSE, recursive = TRUE)
pdf("outputs/Figures_6_B_T2D.pdf", width = 12, height = 12)
# `finally` closes the PDF device even if plotting fails, so later output is
# not drawn into a half-written file.
invisible(tryCatch(plot_endpoint("M_type_2_diabetes"), finally = dev.off()))

In [None]:
# Figure 6C: global attributions for all-cause dementia.
dir.create("outputs", showWarnings = FALSE, recursive = TRUE)
pdf("outputs/Figures_6_C_Dementia.pdf", width = 12, height = 12)
# `finally` closes the PDF device even if plotting fails, so later output is
# not drawn into a half-written file.
invisible(tryCatch(plot_endpoint("M_all_cause_dementia"), finally = dev.off()))

In [None]:
# One global-attribution PDF per endpoint, named by its display label.
dir.create("outputs/GlobalAttributions", showWarnings = FALSE, recursive = TRUE)
for (endpoint in endpoint_order) {
  endpoint_label <- endpoint_map[[endpoint]]
  pdf(glue("outputs/GlobalAttributions/GlobalAttributions_{endpoint_label}.pdf"), width = 12, height = 12)
  # Always close the device, even when plotting this endpoint fails.
  tryCatch(plot_endpoint(endpoint), finally = dev.off())
}