# Benchmarks

## Initialize

In [None]:
#library(Rmisc)
library(dtplyr)
library(tidyverse)
library(glue)
library(arrow)
library(patchwork)
library(data.table)
library("jsonlite")
library(ggthemes)

In [None]:
# Storage root: hosts whose node name contains "sc" use the /sc-projects
# mount, every other host uses the shared analysis mount.
base_path <- if (grepl("sc", Sys.info()[["nodename"]], fixed = TRUE)) {
  "/sc-projects/sc-proj-ukb-cvd"
} else {
  "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
}
print(base_path)

# Project directory layout derived from the storage root.
project_label <- "22_medical_records"
project_path <- glue("{base_path}/results/projects/{project_label}")
figure_path <- glue("{project_path}/figures")
output_path <- glue("{project_path}/data")

# Experiment identifier and the sub-directory of output_path that holds it.
experiment <- 220627
experiment_path <- glue("{output_path}/{experiment}")

In [None]:
# Size settings shared by the figures in this notebook.
base_size <- 8
title_size <- 10
facet_size <- 10
geom_text_size <- 3

# Install the session-wide default theme: theme_classic with blank facet
# strip backgrounds, thin axis lines/ticks, the legend at the bottom and
# panel.grid.major set to a default element_line().
theme_set(
  theme_classic(base_size = base_size) +
    theme(
      strip.background = element_blank(),
      plot.title = element_text(size = title_size, hjust = 0),
      strip.text.x = element_text(size = facet_size),
      axis.title = element_text(size = 10),
      axis.text = element_text(size = 8, color = "black"),
      legend.position = "bottom",
      axis.line = element_line(size = 0.2),
      axis.ticks = element_line(size = 0.2),
      panel.grid.major = element_line()
    )
)

In [None]:
library(arrow)

## Load Data

In [None]:
# Endpoint definitions read from feather, sorted by endpoint identifier.
endpoint_defs <- arrange(
  arrow::read_feather(glue("{output_path}/phecode_defs_220306.feather")),
  endpoint
)

In [None]:
# Load the prepared attribution table of the current experiment.
# The location is built from experiment_path (base_path / project / data /
# experiment) rather than a hard-coded /sc-projects path, so the cell also
# resolves on hosts where base_path points to the other mount. On the "sc"
# nodes the resulting path is identical to the previous literal.
data_attrib_md <- arrow::read_feather(
  glue("{experiment_path}/attributions_prepared.feather")
)

In [None]:
# Show the distinct eid values present in the attribution table.
distinct(data_attrib_md, eid)

In [None]:
# Endpoints that actually occur in the attribution table.
endpoint_selection <- data_attrib_md[["endpoint"]] %>% unique()

In [None]:
# Add a display name per endpoint: a short hand-picked label for the endpoints
# listed below, the original phecode_string for every other endpoint.
endpoint_defs <- endpoint_defs %>%
  mutate(
    name = coalesce(
      unname(
        c(
          "phecode_008" = "H. pylori",
          "phecode_092-2" = "Sepsis",
          "phecode_105" = "Breast cancer",
          "phecode_107-2" = "Prostate cancer",
          "phecode_123" = "Malignant plasma cell neoplasms",
          "phecode_164" = "Anemia",
          "phecode_200-1" = "Hypothyroidism",
          "phecode_232" = "Vitamin deficiencies",
          "phecode_284" = "Suicide attempt or self harm",
          # "phecode_287-5" = "Drug-induced psychosis",
          "phecode_324-11" = "Parkinson's",
          "phecode_328" = "Dementia",
          # "phecode_404" = "Coronary heart disease",
          "phecode_424" = "Heart failure",
          "phecode_440-11" = "Deep vein thrombosis",
          "phecode_468" = "Pneumonia",
          "phecode_474" = "COPD",
          "phecode_518" = "Appendicitis",
          "phecode_542-1" = "Fibrosis and cirrhosis of liver",
          "phecode_583" = "Chronic kidney disease",
          "phecode_705-1" = "Rheumatoid arthritis",
          "phecode_908-1" = "(Pre)eclampsia"
          # "phecode_976" = "Complication of anesthesia"
        )[as.character(endpoint)]
      ),
      phecode_string
    )
  )

# Named vector endpoint id -> display name, used by the facet labellers.
endpoint_map <- setNames(endpoint_defs$name, endpoint_defs$endpoint)

# Endpoint ids ordered by their phecode interpreted as a number.
endpoint_order <- endpoint_defs %>%
  arrange(as.numeric(phecode)) %>%
  pull(endpoint)

In [None]:
#data_attrib_md %>% filter(str_detect(concept_name, "Psoriasis"))

## Verify link between attributions and outcomes

## Local

In [None]:
# Aggregate the individual Shapley values per endpoint / record / eligable
# combination: row count, mean and sum (NA values ignored for mean and sum).
# Rows are not restricted to eligable == 1 here.
data_attrib_md_agg <- data_attrib_md %>%
  group_by(
    endpoint, record, concept_name,
    n_records, n_events_record, freq_events_record, eligable
  ) %>%
  summarise(
    n_shapley = n(),
    mean_shapley = mean(shapley, na.rm = TRUE),
    sum_shapley = sum(shapley, na.rm = TRUE)
  ) %>%
  ungroup() %>%
  select(
    endpoint, record, concept_name, eligable, freq_events_record,
    n_shapley, mean_shapley, sum_shapley
  )

In [None]:
# Show the aggregated table sorted by endpoint, record and eligable.
arrange(data_attrib_md_agg, endpoint, record, eligable)

In [None]:
library(ggpubr)

In [None]:
# Notebook display size for the faceted figure.
plot_width <- 15
plot_height <- 6
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)

# Facet order: endpoints sorted by phecode interpreted as a number.
endpoint_order <- endpoint_defs %>%
  mutate(phecode_rank = as.numeric(phecode)) %>%
  arrange(phecode_rank) %>%
  as_tibble() %>%
  pull(endpoint)

# Aggregated attributions restricted to the selected endpoints.
temp <- data_attrib_md_agg %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
  filter(endpoint %in% endpoint_selection) %>%
  ungroup()

# Mean Shapley value against freq_events_record, one free-scaled panel per
# endpoint, with a linear fit on top of the points.
attrib_plot <- ggplot(temp, aes(x = mean_shapley, y = freq_events_record)) +
  labs(title = NULL, x = "Shapley Value", y = "Incident Events (%)") +
  geom_point(alpha = 0.7, size = 0.3) +
  geom_smooth(method = "lm") +
  coord_cartesian(ylim = c(0, NA)) +
  facet_wrap(
    ~endpoint,
    scales = "free",
    labeller = labeller(
      endpoint = as_labeller(endpoint_map, default = label_wrap_gen(20))
    ),
    ncol = 8
  ) +
  theme(legend.position = "none")
attrib_plot

In [None]:
# Notebook display size for the faceted figure.
plot_width <- 15
plot_height <- 6
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)

# Facet order: endpoints sorted by phecode interpreted as a number.
endpoint_order <- endpoint_defs %>%
  mutate(phecode_rank = as.numeric(phecode)) %>%
  arrange(phecode_rank) %>%
  as_tibble() %>%
  pull(endpoint)

# Aggregated attributions restricted to the selected endpoints.
temp <- data_attrib_md_agg %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
  filter(endpoint %in% endpoint_selection) %>%
  ungroup()

# Same figure as the mean-based one, but with the summed Shapley value on x.
attrib_plot <- ggplot(temp, aes(x = sum_shapley, y = freq_events_record)) +
  labs(title = NULL, x = "Shapley Value", y = "Incident Events (%)") +
  geom_point(alpha = 0.7, size = 0.3) +
  geom_smooth(method = "lm") +
  coord_cartesian(ylim = c(0, NA)) +
  facet_wrap(
    ~endpoint,
    scales = "free",
    labeller = labeller(
      endpoint = as_labeller(endpoint_map, default = label_wrap_gen(20))
    ),
    ncol = 8
  ) +
  theme(legend.position = "none")
attrib_plot

In [None]:
# Wide overview table: one row per record/concept, one column per
# endpoint x setting x statistic. eligable == 1 is labelled "2_noprior",
# eligable == 0 "1_prior"; the numeric prefixes on the statistics fix their
# order once the generated column names are sorted.
attribution_summary <- temp %>%
  mutate(
    endpoint = recode(endpoint, !!!endpoint_map),
    setting = factor(
      case_when(
        eligable == 1 ~ "2_noprior",
        eligable == 0 ~ "1_prior"
      ),
      levels = c("1_prior", "2_noprior")
    )
  ) %>%
  select(
    record, concept_name, endpoint, setting,
    `1_n_shapley` = n_shapley,
    `2_mean_shapley` = mean_shapley,
    `3_sum_shapley` = sum_shapley
  ) %>%
  pivot_wider(
    names_from = c("endpoint", "setting"),
    values_from = c("1_n_shapley", "2_mean_shapley", "3_sum_shapley"),
    names_glue = "{endpoint}\n{setting}\n{.value}",
    names_sort = TRUE
  ) %>%
  select(record, concept_name, sort(current_vars()))

In [None]:
# Display the wide attribution overview table.
attribution_summary

In [None]:
# Export the wide attribution overview table as CSV.
write_csv(attribution_summary, "outputs/attributions_detail_all.csv")

# Very restricted

In [None]:
# Restricted variant of the wide overview table: only aggregates built from
# at least 16 rows (n_shapley >= 16) are kept before reshaping.
# eligable == 1 is labelled "2_noprior", eligable == 0 "1_prior".
attribution_summary <- temp %>%
  mutate(endpoint = recode(endpoint, !!!endpoint_map)) %>%
  filter(n_shapley >= 16) %>%
  mutate(
    setting = factor(
      case_when(
        eligable == 1 ~ "2_noprior",
        eligable == 0 ~ "1_prior"
      ),
      levels = c("1_prior", "2_noprior")
    )
  ) %>%
  select(
    record, concept_name, endpoint, setting,
    `1_n_shapley` = n_shapley,
    `2_mean_shapley` = mean_shapley,
    `3_sum_shapley` = sum_shapley
  ) %>%
  pivot_wider(
    names_from = c("endpoint", "setting"),
    values_from = c("1_n_shapley", "2_mean_shapley", "3_sum_shapley"),
    names_glue = "{endpoint}\n{setting}\n{.value}",
    names_sort = TRUE
  ) %>%
  select(record, concept_name, sort(current_vars()))

In [None]:
# Display the restricted attribution overview table.
attribution_summary

In [None]:
# Export the restricted attribution overview table as CSV.
write_csv(attribution_summary, "outputs/attributions_detail_all_restricted.csv")

In [None]:
# NOTE(review): when the cells are run top to bottom, attribution_summary at
# this point holds the restricted table (n_shapley >= 16), so this call
# overwrites the unrestricted "attributions_detail_all.csv" written further up
# with restricted data. Confirm whether this cell is a leftover.
attribution_summary %>% write_csv("outputs/attributions_detail_all.csv")

In [None]:
# NOTE(review): attribution_local is not assigned in any visible cell of this
# notebook; this display will fail unless it is defined elsewhere.
attribution_local #%>% write_csv("outputs/attributions_detail.csv")

In [None]:
# Per endpoint (display name), the rows with the 10 largest mean Shapley
# values, numbered in the order slice_max() returns them.
top10_long <- temp %>%
  mutate(endpoint = recode(endpoint, !!!endpoint_map)) %>%
  group_by(endpoint) %>%
  slice_max(mean_shapley, n = 10) %>%
  mutate(rank = row_number()) %>%
  ungroup()

In [None]:
# One column per endpoint, one row per rank: widen by rank, transpose, and
# export the concept names as CSV.
top10_long %>%
  select(endpoint, concept_name, rank) %>%
  pivot_wider(names_from = "rank", values_from = "concept_name") %>%
  t() %>%
  as.data.frame() %>%
  write_csv("outputs/attributions_local_overview.csv")

In [None]:
# NOTE(review): pivot_wider() is called without values_from, so it falls back
# to its default (a column named `value`), which is not among the two columns
# selected here - confirm this exploratory cell still runs.
top10_long %>% select(endpoint, concept_name) %>% pivot_wider(names_from="concept_name")

In [None]:
# Inspect the aggregated attributions for "Heart failure", smallest mean
# Shapley value first.
temp %>%
  mutate(endpoint = recode(endpoint, !!!endpoint_map)) %>%
  filter(endpoint == "Heart failure") %>%
  arrange(mean_shapley)

## Global

In [None]:
# Global aggregation: only rows with eligable == 1, Shapley values summed per
# endpoint / record (NA values ignored). This replaces the aggregate built in
# the "Local" section.
data_attrib_md_agg <- data_attrib_md %>%
  filter(eligable == 1) %>%
  group_by(
    endpoint, record, concept_name,
    n_records, n_events_record, freq_events_record
  ) %>%
  summarise(sum_shapley = sum(shapley, na.rm = TRUE)) %>%
  ungroup()

In [None]:
library(ggpubr)

In [None]:
# Notebook display size for the faceted figure.
plot_width <- 15
plot_height <- 6
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)

# Facet order: endpoints sorted by phecode interpreted as a number.
endpoint_order <- endpoint_defs %>%
  mutate(phecode_rank = as.numeric(phecode)) %>%
  arrange(phecode_rank) %>%
  as_tibble() %>%
  pull(endpoint)

# Global aggregate restricted to the selected endpoints.
temp <- data_attrib_md_agg %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
  filter(endpoint %in% endpoint_selection) %>%
  ungroup()

# Summed Shapley value against freq_events_record, one free-scaled panel per
# endpoint, with a linear fit.
attrib_plot <- ggplot(temp, aes(x = sum_shapley, y = freq_events_record)) +
  labs(title = NULL, x = "Shapley Value", y = "Incident Events (%)") +
  geom_point(alpha = 0.7, size = 0.3) +
  geom_smooth(method = "lm") +
  coord_cartesian(ylim = c(0, NA)) +
  facet_wrap(
    ~endpoint,
    scales = "free",
    labeller = labeller(
      endpoint = as_labeller(endpoint_map, default = label_wrap_gen(20))
    ),
    ncol = 8
  ) +
  theme(legend.position = "none")
attrib_plot

In [None]:
# Wide table of summed Shapley values (one column per endpoint display name)
# with per-record mean counts attached, sorted by the "All-Cause Death" column
# and exported as CSV. mean_freq_records divides n_records by a fixed 500000.
temp %>%
  mutate(endpoint = recode(endpoint, !!!endpoint_map)) %>%
  left_join(
    temp %>%
      group_by(record) %>%
      summarise(
        mean_n_records = mean(n_records),
        mean_freq_records = mean(n_records / 500000)
      ),
    by = "record"
  ) %>%
  select(
    record, concept_name, mean_n_records, mean_freq_records,
    endpoint, sum_shapley
  ) %>%
  pivot_wider(names_from = "endpoint", values_from = "sum_shapley") %>%
  arrange(desc(`All-Cause Death`)) %>%
  write_csv("outputs/attributions_global_detail.csv")

In [None]:
# Per endpoint (display name), the rows with the 10 largest summed Shapley
# values, numbered in the order slice_max() returns them.
top10_long <- temp %>%
  mutate(endpoint = recode(endpoint, !!!endpoint_map)) %>%
  group_by(endpoint) %>%
  slice_max(sum_shapley, n = 10) %>%
  mutate(rank = row_number()) %>%
  ungroup()

In [None]:
# One column per endpoint, one row per rank: widen by rank, transpose, and
# export the concept names as CSV.
top10_long %>%
  select(endpoint, concept_name, rank) %>%
  pivot_wider(names_from = "rank", values_from = "concept_name") %>%
  t() %>%
  as.data.frame() %>%
  write_csv("outputs/attributions_global_overview.csv")

In [None]:
# NOTE(review): pivot_wider() is called without values_from, so it falls back
# to its default (a column named `value`), which is not among the two columns
# selected here - confirm this exploratory cell still runs.
top10_long %>% select(endpoint, concept_name) %>% pivot_wider(names_from="concept_name")

In [None]:
# Inspect the global attributions for "Heart failure", smallest summed
# Shapley value first.
temp %>%
  mutate(endpoint = recode(endpoint, !!!endpoint_map)) %>%
  filter(endpoint == "Heart failure") %>%
  arrange(sum_shapley)

## Get Individual Attributions

In [None]:
plot_width <- 10
plot_height <- 4
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)

# Endpoint to inspect and the eid of one randomly drawn row for it.
endpoint <- "phecode_424"
eid <- data_attrib_md %>%
  filter(endpoint == !!endpoint) %>%
  sample_n(1) %>%
  pull("eid")
print(eid)
# 1487118 is nice!

# Rows of that individual and endpoint; concepts are ordered by Shapley value
# and only the 20 largest absolute values are kept.
temp_ind <- data_attrib_md %>%
  filter(eid == !!eid, endpoint == !!endpoint) %>%
  mutate(concept_name = fct_reorder(concept_name, shapley)) %>%
  slice_max(abs(shapley), n = 20)

# Lollipop chart: one horizontal segment from 0 to the Shapley value per
# concept, coloured by sign; long concept names are wrapped at 40 characters.
ggplot(temp_ind, aes(x = concept_name, y = shapley, color = shapley < 0)) +
  labs(y = "SHAPLEY", x = "Record") +
  geom_point() +
  geom_segment(aes(xend = concept_name, y = 0, yend = shapley)) +
  coord_flip() +
  scale_x_discrete(labels = function(x) str_wrap(x, width = 40))

## Create sets for ablations

In [None]:
# Number of distinct records in the aggregated attribution table.
n_distinct(data_attrib_md_agg$record)

In [None]:
# group_by() without grouping variables; effectively just shows the table.
group_by(data_attrib_md_agg)

In [None]:
plot_width <- 15
plot_height <- 12
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)

# Facet order: endpoints sorted by phecode interpreted as a number.
endpoint_order <- endpoint_defs %>%
  mutate(phecode_rank = as.numeric(phecode)) %>%
  arrange(phecode_rank) %>%
  as_tibble() %>%
  pull(endpoint)

# Un-aggregated attribution rows of the selected endpoints.
temp <- data_attrib_md %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
  filter(endpoint %in% endpoint_selection) %>%
  ungroup()

# NOTE(review): earlier cells read the column `shapley`; confirm that a
# column named `shap` exists in data_attrib_md.
attrib_plot <- ggplot(temp, aes(x = shap, y = freq_events_record)) +
  labs(title = NULL, x = "SHAP", y = "Incident Events (%)") +
  geom_point(alpha = 0.7, size = 0.3) +
  geom_smooth(method = "gam") +
  facet_wrap(
    ~endpoint,
    scales = "free",
    labeller = labeller(
      endpoint = as_labeller(endpoint_map, default = label_wrap_gen(20))
    ),
    ncol = 6
  ) +
  theme(legend.position = "none")
attrib_plot

In [None]:
plot_width <- 15
plot_height <- 12
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)

# Facet order: endpoints sorted by phecode interpreted as a number.
endpoint_order <- endpoint_defs %>%
  mutate(phecode_rank = as.numeric(phecode)) %>%
  arrange(phecode_rank) %>%
  as_tibble() %>%
  pull(endpoint)

# Un-aggregated attribution rows of the selected endpoints.
temp <- data_attrib_md %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
  filter(endpoint %in% endpoint_selection) %>%
  ungroup()

# The y aesthetic is the `jaccard` column, so the axis is labelled after that
# column (the previous "Incident Events (%)" label was copied from the
# neighbouring plots and did not describe this axis).
# NOTE(review): earlier cells read the column `shapley`; confirm that columns
# named `shap` and `jaccard` exist in data_attrib_md.
attrib_plot <- ggplot(temp, aes(x = shap, y = jaccard)) +
  labs(title = NULL, x = "SHAP", y = "Jaccard") +
  geom_point(alpha = 0.7, size = 0.3) +
  geom_smooth(method = "gam") +
  facet_wrap(
    ~endpoint,
    scales = "free",
    labeller = labeller(
      endpoint = as_labeller(endpoint_map, default = label_wrap_gen(20))
    ),
    ncol = 6
  ) +
  theme(legend.position = "none")
attrib_plot

In [None]:
# library() instead of require(): fail immediately if a package is missing.
library(scales)
library(ggrepel)

options(repr.plot.width = 8.25, repr.plot.height = 5, repr.plot.res = 600)

# SHAP against record frequency (n_records / n_eligable) on a log10 x axis.
# Fix: geom_smooth was written without "()" and without a trailing "+", which
# ended the plot expression after labs() and left the remaining layers as a
# separate, unused expression. The smoother is added first so the points are
# drawn on top of it, matching the order in which the layers were written.
rf_plot <- ggplot(data_attrib_md, aes(x = n_records / n_eligable, y = shap)) +
  labs(x = "Record Frequency", y = "SHAP [%]") +
  geom_smooth() +
  scale_x_log10(
    expand = c(0, 0),
    labels = trans_format("log10", math_format(10^.x))
  ) +
  geom_point(size = 0.1, alpha = 0.6) +
  theme(legend.position = "none")
rf_plot

In [None]:
plot_width <- 15
plot_height <- 12
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)

# Facet order: endpoints sorted by phecode interpreted as a number.
endpoint_order <- endpoint_defs %>%
  mutate(phecode_rank = as.numeric(phecode)) %>%
  arrange(phecode_rank) %>%
  as_tibble() %>%
  pull(endpoint)

# Un-aggregated attribution rows of the selected endpoints.
temp <- data_attrib_md %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
  filter(endpoint %in% endpoint_selection) %>%
  ungroup()

# NOTE(review): earlier cells use `shapley` and `freq_events_record`; confirm
# that columns named `shap` and `freq_events` exist in data_attrib_md.
attrib_plot <- ggplot(temp, aes(x = shap, y = freq_events)) +
  labs(title = NULL, x = "SHAP", y = "Incident Events (%)") +
  geom_point(alpha = 0.7, size = 0.3) +
  geom_smooth(method = "gam") +
  facet_wrap(
    ~endpoint,
    scales = "free",
    labeller = labeller(
      endpoint = as_labeller(endpoint_map, default = label_wrap_gen(20))
    ),
    ncol = 6
  ) +
  theme(legend.position = "none")
attrib_plot

In [None]:
plot_width <- 15
plot_height <- 12
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)

# Facet order: endpoints sorted by phecode interpreted as a number.
endpoint_order <- endpoint_defs %>%
  mutate(phecode_rank = as.numeric(phecode)) %>%
  arrange(phecode_rank) %>%
  as_tibble() %>%
  pull(endpoint)

# Un-aggregated attribution rows of the selected endpoints.
temp <- data_attrib_md %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
  filter(endpoint %in% endpoint_selection) %>%
  ungroup()

# Axes are flipped relative to the previous plot (x = freq_events, y = shap).
# Fix: the axis labels were still in the un-flipped order, so each axis
# carried the other's title; they now follow the aesthetics.
# NOTE(review): earlier cells use `shapley` and `freq_events_record`; confirm
# that columns named `shap` and `freq_events` exist in data_attrib_md.
attrib_plot <- ggplot(temp, aes(x = freq_events, y = shap)) +
  labs(title = NULL, x = "Incident Events (%)", y = "SHAP") +
  geom_point(alpha = 0.7, size = 0.3) +
  geom_smooth(method = "gam") +
  facet_wrap(
    ~endpoint,
    scales = "free_x",
    labeller = labeller(
      endpoint = as_labeller(endpoint_map, default = label_wrap_gen(20))
    ),
    ncol = 6
  ) +
  theme(legend.position = "none")
attrib_plot

In [None]:
# Only displays the path of the local SHAP feather file; nothing is read here.
# NOTE(review): data_shap, used by the following cells, is not assigned in any
# visible cell - presumably it should be loaded from this path.
glue("{experiment_path}/shap_local.feather")

In [None]:
# Display data_shap.
# NOTE(review): data_shap is not assigned in any visible cell of this notebook.
data_shap

In [None]:
# Endpoint names are all column names of data_shap except the first one.
# Negative indexing replaces 2:length(...), which counts backwards (2, 1) when
# there is only a single column and would then return NA plus the first name.
endpoints <- colnames(data_shap)[-1]

## Load Metadata

In [None]:
# Raw concept table; quote = "" turns off quote handling while parsing.
concepts_raw <- fread(
  "/sc-projects/sc-proj-ukb-cvd/data/mapping/athena/CONCEPT.csv",
  quote = ""
)

In [None]:
# Concept lookup keyed by record id ("OMOP_<concept_id>") with the concept
# name and domain.
concept_defs <- concepts_raw %>%
  as_tibble() %>%
  mutate(record = as.character(glue("OMOP_{concept_id}"))) %>%
  select(record, concept_name, domain_id)

In [None]:
# Reload the endpoint definitions, keep only the endpoints that are columns of
# data_shap, and sort by endpoint identifier.
endpoint_defs <- glue("{output_path}/phecode_defs_220306.feather") %>%
  arrow::read_feather() %>%
  filter(endpoint %in% endpoints) %>%
  arrange(endpoint)

## Preparation

In [None]:
# Display data_shap.
# NOTE(review): data_shap is not assigned in any visible cell of this notebook.
data_shap

In [None]:
# Attach record counts and concept metadata to the SHAP table and add
# `overall`, the row-wise sum over the endpoint columns from OMOP_4306655 to
# phecode_979 (NA values ignored). Both joins rely on the columns the tables
# have in common, since no `by` is given.
# NOTE(review): data_records_agg is not assigned in any visible cell.
data_shap_md <- data_shap %>%
  left_join(data_records_agg) %>%
  left_join(concept_defs) %>%
  rowwise() %>%
  mutate(
    overall = sum(c_across(OMOP_4306655:phecode_979), na.rm = TRUE)
  ) %>%
  ungroup() %>%
  select(
    record, concept_name, domain_id, n, freq, overall,
    everything()
  )

In [None]:
# Records whose count column n is below 10.
filter(data_shap_md, n < 10)

In [None]:
# Look up endpoints whose phecode_string contains "Anemia".
filter(endpoint_defs, str_detect(phecode_string, "Anemia"))

In [None]:
# Distinct concept domains present in the SHAP table.
data_shap_md[["domain_id"]] %>% unique()

In [None]:
# Rank the named concepts by their SHAP value for a single endpoint column,
# largest first.
endpoint <- "phecode_164"
data_shap_md %>%
  select(concept_name, domain_id, n, freq, all_of(endpoint)) %>%
  filter(!is.na(concept_name)) %>%
  arrange(desc(.data[[endpoint]]))

In [None]:
# Endpoint whose SHAP column is plotted. The name contains a hyphen, so it is
# not a syntactic R name.
endpoint <- "phecode_175-2"

plot_width <- 10
plot_height <- 5
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)
endpoint_label <- endpoint_defs %>%
  filter(endpoint == !!endpoint) %>%
  pull(phecode_string)

# NOTE(review): the mutate() adds (or overwrites) a column called `endpoint`;
# if data_shap_md has no such column it is filled from the string above.
temp <- data_shap_md %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_defs$endpoint)) %>%
  filter(n > 50)

# Fix: aes_string() parses its strings as R code, so "phecode_175-2" was read
# as `phecode_175 - 2`. The column is now referenced with !!sym(endpoint), the
# same way the filter() below already does it.
ggplot(
  temp,
  aes(x = record, y = !!sym(endpoint), color = !!sym(endpoint))
) +
  labs(title = endpoint_label, y = "SHAP") +
  geom_point(size = 0.2) +
  # Label only records whose absolute SHAP value exceeds 0.15.
  geom_text(
    data = temp %>% filter(abs(!!sym(endpoint)) > 0.15),
    aes(label = stringr::str_wrap(concept_name, 30)),
    size = 2,
    check_overlap = TRUE
  ) +
  # Diverging colour scale centred at 0; values outside [-0.1, 0.1] are
  # squished to the limits.
  scale_colour_gradient2(
    low = "blue",
    mid = alpha("white", 0.01),
    high = "red",
    midpoint = 0,
    limits = c(-0.1, 0.1),
    oob = scales::squish
  ) +
  theme(plot.title = element_text(hjust = 0.5), legend.position = "None")

In [None]:
# Records ordered by `overall` (largest first), keeping those with n > 10.
data_shap_md %>%
  arrange(desc(overall)) %>%
  filter(n > 10)

In [None]:
endpoint <- "OMOP_4306655"

plot_width <- 10
plot_height <- 5
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)

endpoint_label <- endpoint_defs %>%
  filter(endpoint == !!endpoint) %>%
  pull(phecode_string)

# SHAP value of the chosen endpoint column against the record's `freq` column.
# The column is referenced with !!sym(endpoint) instead of the deprecated
# aes_string(), which parses strings as R code and therefore breaks for
# hyphenated endpoint names such as "phecode_175-2".
ggplot(data_shap_md, aes(x = freq, y = !!sym(endpoint))) +
  labs(title = endpoint_label, y = "SHAP") +
  geom_point(size = 0.2) +
  theme(plot.title = element_text(hjust = 0.5))

In [None]:
endpoint <- "OMOP_4306655"

plot_width <- 10
plot_height <- 5
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)
endpoint_label <- endpoint_defs %>%
  filter(endpoint == !!endpoint) %>%
  pull(phecode_string)

temp <- data_shap_md %>% filter(n > 20)

# Per-record SHAP values for the chosen endpoint, stored in `test` for the
# plotly conversion in a later cell. The column is referenced with
# !!sym(endpoint) instead of the deprecated aes_string(), which parses strings
# as R code and therefore breaks for hyphenated endpoint names.
test <- ggplot(
  temp,
  aes(x = record, y = !!sym(endpoint), color = !!sym(endpoint))
) +
  labs(title = endpoint_label, y = "SHAP") +
  geom_point(size = 0.2) +
  # Label only records whose absolute SHAP value exceeds 0.15.
  geom_text(
    data = temp %>% filter(abs(!!sym(endpoint)) > 0.15),
    aes(label = stringr::str_wrap(concept_name, 30)),
    size = 2,
    check_overlap = TRUE
  ) +
  # Diverging colour scale centred at 0; values outside [-0.1, 0.1] are
  # squished to the limits.
  scale_colour_gradient2(
    low = "blue",
    mid = alpha("white", 0.01),
    high = "red",
    midpoint = 0,
    limits = c(-0.1, 0.1),
    oob = scales::squish
  ) +
  theme(plot.title = element_text(hjust = 0.5), legend.position = "None")

In [None]:
library(plotly)

In [None]:
# Interactive version of the `test` plot with concept_name as tooltip.
y <- ggplotly(test, tooltip = "concept_name")

In [None]:
# Write the interactive plot to an HTML file in the working directory.
htmlwidgets::saveWidget(widget = y, file = "death_attributions.html")

## Compare against event frequency
-> x = SHAP, y=event rate

In [None]:
# Display the endpoint definitions.
endpoint_defs

In [None]:
# Endpoint column used by the following plots.
endpoint <- "OMOP_4306655"

In [None]:
plot_width <- 10
plot_height <- 5
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)
endpoint_label <- endpoint_defs %>%
  filter(endpoint == !!endpoint) %>%
  pull(phecode_string)

# Jittered SHAP values of the chosen endpoint column against `freq`, coloured
# by the SHAP value itself. The column is referenced with !!sym(endpoint)
# instead of the deprecated aes_string(), which parses strings as R code and
# therefore breaks for hyphenated endpoint names.
ggplot(
  data_shap_md,
  aes(x = freq, y = !!sym(endpoint), color = !!sym(endpoint))
) +
  labs(title = endpoint_label, y = "SHAP") +
  geom_jitter(size = 0.2) +
  # Diverging colour scale centred at 0; values outside [-0.1, 0.1] are
  # squished to the limits.
  scale_colour_gradient2(
    low = "blue",
    mid = alpha("white", 0.1),
    high = "red",
    midpoint = 0,
    limits = c(-0.1, 0.1),
    oob = scales::squish
  ) +
  theme(plot.title = element_text(hjust = 0.5))

In [None]:
plot_width <- 10
plot_height <- 5
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)
endpoint_label <- endpoint_defs %>%
  filter(endpoint == !!endpoint) %>%
  pull(phecode_string)

# SHAP value of the chosen endpoint column against `freq`, without colouring.
# The column is referenced with !!sym(endpoint) instead of the deprecated
# aes_string(), which parses strings as R code and therefore breaks for
# hyphenated endpoint names.
ggplot(data_shap_md, aes(x = freq, y = !!sym(endpoint))) +
  labs(title = endpoint_label, y = "SHAP") +
  geom_point(size = 0.2) +
  theme(plot.title = element_text(hjust = 0.5))

In [None]:
# Records ordered by their OMOP_4306655 value (largest first), n > 10 only.
data_shap_md %>%
  arrange(desc(OMOP_4306655)) %>%
  filter(n > 10)

In [None]:
# Stack the 15 prediction tables p0..p14 and wrap the result as a lazy
# data.table.
# NOTE(review): p0..p14 are not assigned in any visible cell.
predictions <- bind_rows(
  p0, p1, p2, p3, p4, p5, p6, p7,
  p8, p9, p10, p11, p12, p13, p14
) %>%
  lazy_dt()

In [None]:
plot_width <- 10
plot_height <- 10
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)

# Density of logh per partition.
# NOTE(review): confirm that `temp` holds `logh` and `partition` columns at
# this point; the visible cells above assign other tables to `temp`.
ggplot(temp, aes(x = logh, color = partition, fill = partition)) +
  geom_density(alpha = 0.2)

In [None]:
# Size settings shared by the figures below.
base_size <- 8
title_size <- 10
facet_size <- 10
geom_text_size <- 3

# Install the default theme again, this time without setting
# panel.grid.major: theme_classic with blank facet strip backgrounds, thin
# axis lines/ticks and the legend at the bottom.
theme_set(
  theme_classic(base_size = base_size) +
    theme(
      strip.background = element_blank(),
      plot.title = element_text(size = title_size, hjust = 0),
      strip.text.x = element_text(size = facet_size),
      axis.title = element_text(size = 10),
      axis.text = element_text(size = 8, color = "black"),
      legend.position = "bottom",
      axis.line = element_line(size = 0.2),
      axis.ticks = element_line(size = 0.2)
    )
)

In [None]:
# Colour palette read from colors.json (path relative to the working
# directory).
colors_dict = read_json("colors.json")
# Model label -> colour lookup. All five model labels currently point at the
# same palette entry (pastel > red > mid).
# NOTE(review): confirm whether distinct colours per model were intended.
color_map <- c(
    "Identity(AgeSex)+MLP" = colors_dict$pastel$red$mid,
    "Identity(Records)+MLP" = colors_dict$pastel$red$mid,
    "GNN(Records)+MLP" = colors_dict$pastel$red$mid,
    "Identity(AgeSex+Records)+MLP" = colors_dict$pastel$red$mid,
    "GNN(AgeSex+Records)+MLP" = colors_dict$pastel$red$mid
)

In [None]:
# Phecode definitions, extended with one extra row for all-cause death and an
# `endpoint` id column: "OMOP_4306655" for the death row, "phecode_<code>" for
# every other row, where the phecode is split into two parts that are
# re-joined with "-" (a missing second part is dropped).
phecode_defs_path <- "/sc-projects/sc-proj-ukb-cvd/data/mapping/phecodes/phecode_strings_V2.csv"
phecode_defs <- fread(
  phecode_defs_path,
  colClasses = c(
    "character", "character", "character", "character",
    "integer", "character", "integer"
  )
) %>%
  add_row(
    phecode = "4306655",
    phecode_string = "All-Cause Death",
    phecode_category = "Death",
    sex = "Both"
  ) %>%
  as_tibble() %>%
  separate(phecode, into = c("first", "second"), remove = FALSE) %>%
  mutate(comb = str_remove_all(glue("{first}-{second}"), "-NA")) %>%
  mutate(
    endpoint = case_when(
      comb == "4306655" ~ glue("OMOP_{comb}"),
      TRUE ~ glue("phecode_{comb}")
    )
  ) %>%
  select(phecode, endpoint, everything(), -first, -second, -comb)

# Spot-check five random rows.
sample_n(phecode_defs, 5)

In [None]:
# Event frequency per endpoint among rows with prevalent == 0
# (sum of event divided by the number of rows).
# NOTE(review): data_outcomes is not assigned in any visible cell.
outcome_freq <- data_outcomes %>%
  filter(prevalent == 0) %>%
  group_by(endpoint) %>%
  summarize(freq = sum(event) / n()) %>%
  as_tibble()

# Show endpoints from most to least frequent.
arrange(outcome_freq, desc(freq))

In [None]:
# Named vector endpoint id -> phecode_string, used as facet labeller.
endpoint_map <- setNames(phecode_defs$phecode_string, phecode_defs$endpoint)

# Endpoint ids ordered from most to least frequent outcome.
endpoint_order_freq <- outcome_freq %>%
  arrange(desc(freq)) %>%
  pull(endpoint)

## Load data

In [None]:
# List every directory below the project path (full paths, recursive).
list.dirs(project_path, full.names = TRUE, recursive = TRUE)

# Figure 2: Selected Endpoints

## MedicalHistory state and incident disease

In [None]:
# Attach the outcome columns to every prediction row and materialise the
# result as a data.table.
# Fix: left_join() has no `on` argument (that is data.table syntax); the join
# keys are now passed explicitly via `by`.
# NOTE(review): assumes both tables carry `eid` and `endpoint` columns, as the
# original `on=c(eid, endpoint)` intended - confirm against data_outcomes.
pred_outcomes <- predictions %>%
  left_join(data_outcomes, by = c("eid", "endpoint")) %>%
  as.data.table()

In [None]:
# Rows with prevalent == 0, with logh binned into 100 equally sized groups
# (ntile) separately for every endpoint / model combination.
logh_inc <- pred_outcomes %>%
  filter(prevalent == 0) %>%
  group_by(endpoint, model) %>%
  mutate(logh_perc = ntile(logh, 100)) %>%
  ungroup() %>%
  as_tibble()

## No buffer

In [None]:
# Mean of `event` per endpoint, model and logh percentile bin.
logh_T_agg <- logh_inc %>%
  group_by(endpoint, model, logh_perc) %>%
  summarise(ratio = mean(event)) %>%
  as_tibble()

In [None]:
# Persist the per-percentile event rates as feather.
write_feather(logh_T_agg, glue("{output_path}/logh_agg_220224.feather"))

In [None]:
plot_width <- 50
plot_height <- 75
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)

# Facet order: endpoints sorted by phecode interpreted as a number; all of
# them are selected.
endpoint_order <- phecode_defs %>%
  mutate(phecode_rank = as.numeric(phecode)) %>%
  arrange(phecode_rank) %>%
  as_tibble() %>%
  pull(endpoint)
endpoint_selection <- endpoint_order

# Per-percentile event rates of the "GNN(Records)+MLP" model only.
temp <- logh_T_agg %>%
  filter(model == "GNN(Records)+MLP") %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
  filter(endpoint %in% endpoint_selection) %>%
  ungroup()

# Observed event rate (in %) per percentile bin, one free-scaled panel per
# endpoint, points shaded from light to dark blue with increasing percentile.
mh_events <- ggplot(temp, aes(x = logh_perc, y = ratio * 100, color = logh_perc)) +
  labs(
    title = NULL,
    x = "Medical History Percentile [%]",
    y = "Observed Event Rate [%]"
  ) +
  geom_point(alpha = 0.7, size = 0.1) +
  scale_colour_gradient(
    low = "#7AC6FF",
    high = "#023768",
    space = "Lab",
    na.value = "grey50",
    guide = "colourbar",
    aesthetics = "colour"
  ) +
  scale_y_continuous(limits = c(0, NA), expand = expansion(mult = c(0, .05))) +
  scale_x_continuous(expand = expansion(add = c(0, 1))) +
  facet_wrap(
    ~endpoint,
    scales = "free",
    labeller = labeller(endpoint = endpoint_map),
    ncol = 25
  ) +
  theme(legend.position = "none")

In [None]:
# Save the event-rate figure as PDF; limitsize = FALSE allows the large
# 50 x 75 canvas.
plot_name <- "MedicalHistoryRisk"
ggsave(
  filename = glue("outputs/{plot_name}.pdf"),
  plot = mh_events,
  device = "pdf",
  width = plot_width,
  height = plot_height,
  dpi = plot_res,
  limitsize = FALSE
)

In [None]:
# Save the event-rate figure as PNG; limitsize = FALSE allows the large
# 50 x 75 canvas.
plot_name <- "MedicalHistoryRisk"
ggsave(
  filename = glue("outputs/{plot_name}.png"),
  plot = mh_events,
  device = "png",
  width = plot_width,
  height = plot_height,
  dpi = plot_res,
  limitsize = FALSE
)

## Add buffer

In [None]:
# Peek at the first rows of the joined prediction/outcome table.
head(pred_outcomes)

In [None]:
# Same per-percentile aggregation as logh_T_agg, but events with time < 1 are
# counted as non-events (event_buffer = 0) before averaging.
# NOTE(review): the output file name says "1ybuffer", so `time` is presumably
# measured in years - confirm.
logh_T_agg_buffer <- pred_outcomes %>%
  filter(prevalent == 0) %>%
  mutate(
    event_buffer = case_when(
      (event != 0 & time < 1) ~ 0,
      TRUE ~ event
    )
  ) %>%
  group_by(endpoint, model) %>%
  mutate(logh_perc = ntile(logh, 100)) %>%
  group_by(endpoint, model, logh_perc) %>%
  summarise(ratio = mean(event_buffer)) %>%
  as_tibble()

In [None]:
# Persist the buffered per-percentile event rates as feather.
write_feather(
  logh_T_agg_buffer,
  glue("{output_path}/logh_agg_1ybuffer_220224.feather")
)

In [None]:
plot_width <- 50
plot_height <- 75
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)

# Facet order: endpoints sorted by phecode interpreted as a number; all of
# them are selected.
endpoint_order <- phecode_defs %>%
  mutate(phecode_rank = as.numeric(phecode)) %>%
  arrange(phecode_rank) %>%
  as_tibble() %>%
  pull(endpoint)
endpoint_selection <- endpoint_order

# Buffered per-percentile event rates of the "GNN(Records)+MLP" model only.
temp <- logh_T_agg_buffer %>%
  filter(model == "GNN(Records)+MLP") %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
  filter(endpoint %in% endpoint_selection) %>%
  ungroup()

# Same figure as the unbuffered one, built from the buffered rates; this
# replaces the previous mh_events object.
mh_events <- ggplot(temp, aes(x = logh_perc, y = ratio * 100, color = logh_perc)) +
  labs(
    title = NULL,
    x = "Medical History Percentile [%]",
    y = "Observed Event Rate [%]"
  ) +
  geom_point(alpha = 0.7, size = 0.1) +
  scale_colour_gradient(
    low = "#7AC6FF",
    high = "#023768",
    space = "Lab",
    na.value = "grey50",
    guide = "colourbar",
    aesthetics = "colour"
  ) +
  scale_y_continuous(limits = c(0, NA), expand = expansion(mult = c(0, .05))) +
  scale_x_continuous(expand = expansion(add = c(0, 1))) +
  facet_wrap(
    ~endpoint,
    scales = "free",
    labeller = labeller(endpoint = endpoint_map),
    ncol = 25
  ) +
  theme(legend.position = "none")

In [None]:
# Save the buffered event-rate figure as PNG.
plot_name <- "MedicalHistoryRisk_1ybuffer"
ggsave(
  filename = glue("outputs/{plot_name}.png"),
  plot = mh_events,
  device = "png",
  width = plot_width,
  height = plot_height,
  dpi = plot_res,
  limitsize = FALSE
)

## MedicalHistory and Event Trajectories

In [None]:
# Assign every row to a risk group by its logh percentile bin:
# bins 91-100 -> "High", 45-55 -> "Mid", 1-10 -> "Low"; all other bins are
# tagged "NA" (as a string) and dropped at the end.
# Fix: the factor was built from a column `MET` that does not exist in this
# table; it is now built from the `MH` column created just above. The factor
# is still stored as `MET` because the plotting cell below maps colour, fill
# and group to `MET`.
logh_mh <- logh_inc %>%
  select(endpoint, model, eid, logh_perc, event, time) %>%
  group_by(endpoint) %>%
  mutate(
    MH = case_when(
      logh_perc %in% 91:100 ~ "High",
      logh_perc %in% 45:55 ~ "Mid",
      logh_perc %in% 1:10 ~ "Low",
      TRUE ~ "NA"
    )
  ) %>%
  # Reversed level order: High, Mid, Low.
  mutate(MET = fct_rev(factor(MH, levels = c("Low", "Mid", "High")))) %>%
  ungroup() %>%
  filter(MH != "NA")

In [None]:
# library() instead of require(): a missing package stops the cell right here
# instead of failing later at the first geom_km() call.
library("ggquickeda")
plot_width <- 50
plot_height <- 75
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)
# Colours of the three risk groups (dark = High, light = Low).
met_map <- c("High" = "#023768", "Mid" = "#4F8EC1", "Low" = "#7AC6FF")

# Facet order: endpoints sorted by phecode interpreted as a number; all of
# them are selected.
endpoint_order <- phecode_defs %>%
  mutate(phecode_rank = as.numeric(phecode)) %>%
  arrange(phecode_rank) %>%
  as_tibble() %>%
  pull(endpoint)
endpoint_selection <- endpoint_order
temp <- logh_mh %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
  filter(endpoint %in% endpoint_selection) %>%
  ungroup()

# Cumulative-event curves per risk group (geom_km with trans = "event"), with
# tick marks and bands, one free-scaled panel per endpoint. The y labels are
# shown as percentages rounded to one decimal.
km_plot <- ggplot(
  temp,
  aes(time = time, status = event, fill = MET, color = MET, group = MET)
) +
  geom_km(trans = "event") +
  geom_kmticks(trans = "event", size = 0.3) +
  geom_kmband(trans = "event") +
  labs(x = "Time [Years]", y = "Cumulative Events [%]") +
  scale_color_manual(values = met_map) +
  scale_fill_manual(values = met_map) +
  scale_y_continuous(labels = function(x) round(x * 100, 1), expand = c(0, 0)) +
  scale_x_continuous(expand = expansion(add = c(0, .1)), breaks = c(5, 10)) +
  facet_wrap(
    ~endpoint,
    scales = "free",
    labeller = labeller(endpoint = endpoint_map),
    ncol = 25
  ) +
  theme(legend.position = "none")

In [None]:
# Save the cumulative-event figure as PNG.
plot_name <- "MedicalHistory_KMs"
ggsave(
  filename = glue("outputs/{plot_name}.png"),
  plot = km_plot,
  device = "png",
  width = plot_width,
  height = plot_height,
  dpi = plot_res,
  limitsize = FALSE
)

# Top 1%

In [None]:
# Narrower risk groups: only the top bin (100) is "High", bins 50-51 are
# "Mid" and the bottom bin (1) is "Low"; all other bins are tagged "NA"
# (as a string) and dropped at the end.
logh_mh <- logh_inc %>%
  select(endpoint, model, eid, logh_perc, event, time) %>%
  group_by(endpoint) %>%
  mutate(
    MH = case_when(
      logh_perc == 100 ~ "High",
      logh_perc %in% 50:51 ~ "Mid",
      logh_perc == 1 ~ "Low",
      TRUE ~ "NA"
    )
  ) %>%
  # Factor copy of MH with reversed level order: High, Mid, Low.
  mutate(MET = fct_rev(factor(MH, levels = c("Low", "Mid", "High")))) %>%
  ungroup() %>%
  filter(MH != "NA")

In [None]:
# library() instead of require(): a missing package stops the cell right here
# instead of failing later at the first geom_km() call.
library("ggquickeda")
plot_width <- 50
plot_height <- 75
plot_res <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_res
)
# Colours of the three risk groups (dark = High, light = Low).
mh_map <- c("High" = "#023768", "Mid" = "#4F8EC1", "Low" = "#7AC6FF")

# Facet order: endpoints sorted by phecode interpreted as a number; all of
# them are selected.
endpoint_order <- phecode_defs %>%
  mutate(phecode_rank = as.numeric(phecode)) %>%
  arrange(phecode_rank) %>%
  as_tibble() %>%
  pull(endpoint)
endpoint_selection <- endpoint_order
temp <- logh_mh %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
  filter(endpoint %in% endpoint_selection) %>%
  ungroup()

# Cumulative-event curves per MH group, one free-scaled panel per endpoint.
# Fix: the colour scale referenced met_map, which is defined in a different
# cell; both scales now use the mh_map defined above, so this cell no longer
# depends on the earlier cell having been run.
km_plot <- ggplot(
  temp,
  aes(time = time, status = event, fill = MH, color = MH, group = MH)
) +
  geom_km(trans = "event") +
  geom_kmticks(trans = "event", size = 0.3) +
  geom_kmband(trans = "event") +
  labs(x = "Time [Years]", y = "Cumulative Events [%]") +
  scale_color_manual(values = mh_map) +
  scale_fill_manual(values = mh_map) +
  scale_y_continuous(labels = function(x) round(x * 100, 1), expand = c(0, 0)) +
  scale_x_continuous(expand = expansion(add = c(0, .1)), breaks = c(5, 10)) +
  facet_wrap(
    ~endpoint,
    scales = "free",
    labeller = labeller(endpoint = endpoint_map),
    ncol = 25
  ) +
  theme(legend.position = "none")

In [None]:
# Save the top-1% cumulative-event figure as PNG.
plot_name <- "MedicalHistory_KMs_Top1"
ggsave(
  filename = glue("outputs/{plot_name}.png"),
  plot = km_plot,
  device = "png",
  width = plot_width,
  height = plot_height,
  dpi = plot_res,
  limitsize = FALSE
)

# Figure 2

In [None]:
plot_width <- 8.25
plot_height <- 10
plot_dpi <- 320
options(
  repr.plot.width = plot_width,
  repr.plot.height = plot_height,
  repr.plot.res = plot_dpi
)

# Stack the event-rate panel on top of the cumulative-event panel (patchwork).
# Fix: `met_events` is not defined anywhere in this notebook; the event-rate
# figure built above is called `mh_events`.
# NOTE(review): confirm mh_events (the most recently built variant) is the
# panel intended for Figure 2.
fig2 <- mh_events / km_plot

In [None]:
# Render the combined figure in the notebook.
fig2

In [None]:
library(gt)

# Save the combined figure as PDF at the current plot_width / plot_height.
plot_name <- "Figures_2_AB"
ggsave(
  filename = glue("outputs/{plot_name}.pdf"),
  plot = fig2,
  device = "pdf",
  width = plot_width,
  height = plot_height,
  dpi = 320
)