# Benchmarks

## Initialize

In [None]:
#library(Rmisc)
#library(dtplyr)
library(tidyverse)
library(glue)
library(arrow)
library(patchwork)
#library(data.table)
library("jsonlite")
library(ggthemes)

In [None]:
# Root directory that holds the project data.
base_path <- "/home/jakobs"

# Experiment directory: <base_path>/data/<experiment>.
project_path <- glue("{base_path}/data")
experiment <- "230629"
experiment_path <- glue("{project_path}/{experiment}")

# Endpoint metadata, restricted to endpoints with n_epic >= 100.
# The filtered table is kept under two names: `endpoint_defs` receives display
# names further down, `endpoints_md` stays as loaded here.
endpoint_defs <- arrow::read_feather(
    glue("{base_path}/data/endpoints_epic_md.feather")
) %>%
    filter(n_epic >= 100)
endpoints_md <- endpoint_defs
endpoints <- endpoint_defs$endpoint

In [None]:
library(data.table)
library(arrow)

In [None]:
# Endpoints shown in the figures. The order of this vector matters: it is
# reused as `endpoint_order`, i.e. as the factor-level order of the facets.
endpoint_selection <- c(
    # --- generally very important ---
    "phecode_202",    # Diabetes mellitus
    "phecode_401",    # Hypertension
    "phecode_404",    # Ischemic heart disease
    "phecode_404-1",  # Myocardial infarction [Heart attack]
    "phecode_431-11", # Cerebral infarction [Ischemic stroke]
    "phecode_424",    # Heart failure

    # "phecode_059-1", # COVID 19
    "phecode_468",    # Pneumonia
    "phecode_474",    # Chronic obstructive pulmonary disease [COPD]

    "phecode_286-2",  # Major depressive disorder
    "phecode_324-11", # Parkinson's disease
    "phecode_328",    # Dementias and cerebral degeneration

    "phecode_164",    # Anemia
    "phecode_726-1",  # Osteoporosis
    "phecode_371",    # Cataract
    # "phecode_374-42", # Diabetic retinopathy
    # "phecode_374-5",  # Macular degeneration
    # "phecode_375-1",  # Glaucoma

    "phecode_103",    # Malignant neoplasm of the skin
    "phecode_101",    # Malignant neoplasm of the digestive organs
    "phecode_102",    # Lung cancer

    "phecode_583",    # Chronic kidney disease
    "phecode_542",    # Chronic liver disease and sequelae
    "OMOP_4306655"    # All-Cause Death

    # --- disabled: also generally important and relevant ---
    # "phecode_440-3", # Pulmonary embolism
    # "phecode_468-1", # Viral pneumonia
    # "phecode_460-2", # Acute lower respiratory infection
    # "phecode_388",   # Blindness and low vision
    # --- disabled: generally important and fun to check ---
    # "phecode_374-3", # Retinal vascular changes and occlusions
    # "phecode_665",   # Psoriasis
    # "phecode_121",   # Leukemia
    # --- disabled: important for eye ---
    # "phecode_705-1", # Rheumatoid arthritis
)

# Endpoint ids grouped as "common" conditions.
endpoints_common <- c(
    "phecode_164",    # Anemia
    "phecode_705-1",  # Rheumatoid arthritis
    "phecode_328",    # Dementias and cerebral degeneration
    "phecode_328-1",  # Alzheimer's disease
    "phecode_401",    # Hypertension
    "phecode_202",    # Diabetes mellitus
    "phecode_416-21", # Atrial fibrillation
    "phecode_404-1",  # Myocardial infarction [Heart attack]
    "phecode_424",    # Heart failure
    "phecode_468",    # Pneumonia
    "phecode_474",    # Chronic obstructive pulmonary disease [COPD]
    "phecode_583",    # Chronic kidney disease
    "OMOP_4306655"    # All-Cause Death
)
    
# Endpoint ids grouped as cardiovascular conditions.
endpoints_cardio <- c(
    "phecode_438-11", # Abdominal aortic aneurysm
    "phecode_440-3",  # Pulmonary embolism (intervention)
    "phecode_413-21", # Aortic stenosis (intervention)
    "phecode_400"     # Rheumatic fever and chronic rheumatic heart diseases
)

# Endpoint ids grouped as eye conditions.
endpoints_eye <- c(
    "phecode_374-5",  # Macular degeneration
    "phecode_374-51", # Age-related macular degeneration
    "phecode_374-42", # Diabetic retinopathy
    "phecode_371",    # Cataract
    "phecode_388",    # Blindness and low vision
    "phecode_367-5",  # Uveitis
    "phecode_389-1"   # Ocular pain
)

In [None]:
# Give every endpoint a display name: a handful of long phecode labels are
# shortened for plotting, every other endpoint keeps its `phecode_string`.
endpoint_defs <- endpoint_defs %>%
    mutate(
        name = case_when(
            phecode_string == "Myocardial infarction [Heart attack]" ~ "Myocardial infarction",
            phecode_string == "Cerebral infarction [Ischemic stroke]" ~ "Ischemic stroke",
            phecode_string == "Chronic obstructive pulmonary disease [COPD]" ~ "Chronic obstructive pulmonary disease",
            phecode_string == "Mitral valve insufficiency" ~ "Mitral insufficiency",
            phecode_string == "Parkinson's disease (Primary)" ~ "Parkinson's disease",
            phecode_string == "Suicide ideation and attempt or self harm" ~ "Suicide attempt",
            phecode_string == "Ischemic heart disease" ~ "Coronary heart disease",
            phecode_string == "Rheumatic fever and chronic rheumatic heart diseases" ~ "Rheumatic heart disease",
            # phecode_string == "Dementias and cerebral degeneration" ~ "Dementia",
            TRUE ~ phecode_string
        )
    )

# Named lookup: endpoint id -> display name (used by recode() and the facet
# labellers below).
endpoint_map <- setNames(endpoint_defs$name, endpoint_defs$endpoint)

# Facet / factor-level order follows the hand-curated selection.
# endpoint_order <- (endpoint_defs %>% arrange(as.numeric(phecode)))$endpoint
endpoint_order <- endpoint_selection

In [None]:
# Preview the selected endpoint ids with every "-" replaced by "."
# (displayed only; the result is not stored).
str_replace_all(endpoint_selection, fixed("-"), ".")

In [None]:
# Metadata of the selected endpoints, sorted by ascending `n`, shown with
# display names and `freq` expressed in percent.
endpoints_md %>%
    filter(endpoint %in% endpoint_selection) %>%
    as_tibble() %>%
    arrange(n) %>%
    mutate(
        endpoint = recode(endpoint, !!!endpoint_map),
        perc = freq * 100
    )

In [None]:
# Date tag used in input file names; pinned to the experiment date rather than
# derived from the clock.
# today <- substr(Sys.time(), 0, 10) # YYYY-MM-DD
today <- "230629"

In [None]:
# Eligible (eid, endpoint) pairs for the selected endpoints. Both keys are cast
# to character so they can be joined on, and every row is flagged
# `included = 1`.
eligable_eids <- arrow::read_feather(
    glue("{experiment_path}/eligible_eids_long_{today}.feather")
) %>%
    filter(endpoint %in% endpoint_selection) %>%
    mutate(
        endpoint = as.character(endpoint),
        eid = as.character(eid),
        included = 1
    )

In [None]:
# Long-format outcomes for the selected endpoints. Rows whose (eid, endpoint)
# pair appears in `eligable_eids` pick up `included = 1`; all other rows get NA.
#
# `as_data_frame = FALSE` is deliberately NOT passed: it previously sat inside
# the glue() call, where glue silently ignores it as an unused template
# variable, so the file has always been read as a regular data frame. The
# joins below rely on that.
data_outcomes <- arrow::read_feather(
    glue("{base_path}/data/data_outcomes_long_230320.feather")
) %>%
    filter(endpoint %in% endpoint_selection) %>%
    left_join(eligable_eids, by = c("eid", "endpoint"))

In [None]:
# One prediction file per partition. glue() is vectorised over `partitions`,
# so all paths are built in a single call instead of growing a vector in a
# loop; as.character() keeps `paths` a plain character vector.
partitions <- 0:9
# Alternative model: "{experiment_path}/loghs/Identity(Records)+MLP/{partitions}/test.feather"
paths <- as.character(
    glue("{experiment_path}/loghs/RetinaUKB/{partitions}/test.feather")
)

In [None]:
# Read every partition file, keep `eid` plus the selected endpoint columns,
# reshape to long format (one row per eid x endpoint with its `logh` value) and
# stack the partitions. Warnings raised while reading are suppressed.
predictions <- map_df(paths, function(path) {
    suppressWarnings(
        read_feather(path, col_select = c("eid", all_of(endpoint_selection)))
    ) %>%
        pivot_longer(
            all_of(endpoint_selection),
            names_to = "endpoint",
            values_to = "logh"
        ) %>%
        mutate(eid = as.character(eid)) %>%
        select(endpoint, eid, logh)
}) %>%
    # Keep only endpoints that are present in the endpoint metadata.
    filter(endpoint %in% endpoints_md$endpoint) # %>% arrange(endpoint, eid)
head(predictions)

In [None]:
# Shared size settings for all figures.
base_size <- 8
title_size <- 10
facet_size <- 7
geom_text_size <- 3

# Global ggplot theme: classic look, no facet strip background, legend at the
# bottom, thin axis lines and ticks.
# NOTE(review): element_line(size = ) is deprecated in favour of `linewidth`
# in recent ggplot2 releases — confirm the ggplot2 version before changing it.
theme_set(
    theme_classic(base_size = base_size) +
        theme(
            strip.background = element_blank(),
            plot.title = element_text(size = title_size, hjust = 0),
            strip.text.x = element_text(size = facet_size),
            axis.title = element_text(size = 10),
            axis.text = element_text(size = 8, color = "black"),
            legend.position = "bottom",
            axis.line = element_line(size = 0.2),
            axis.ticks = element_line(size = 0.2)
        )
)

# Figure 2: Selected Endpoints

In [None]:
# Attach outcome columns to every prediction row (matched on eid + endpoint);
# predictions without a matching outcome row are kept with NA outcome columns.
pred_outcomes <- as_tibble(
    left_join(predictions, data_outcomes, by = c("eid", "endpoint"))
)

## Endpoint Prevalence + Rate Ratios

In [None]:
# Keep rows flagged `included == 1` and, within each endpoint, rank `logh`
# into 10 equally sized bins (`logh_perc`, 1 = lowest values).
logh_inc <- pred_outcomes %>%
    filter(included == 1) %>%
    group_by(endpoint) %>%
    mutate(logh_perc = ntile(logh, 10)) %>%
    ungroup() %>%
    as_tibble()

In [None]:
# Per endpoint and risk bin: number of events (`n`) and event proportion
# (`ratio`).
logh_T_agg <- logh_inc %>%
    group_by(endpoint, logh_perc) %>%
    summarise(
        n = sum(event),
        ratio = mean(event)
    ) %>%
    as_tibble()

In [None]:
# Bin summary restricted to the selected endpoints, with `endpoint` as a factor
# ordered like `endpoint_order`.
temp <- logh_T_agg %>%
    mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
    filter(endpoint %in% endpoint_selection) %>%
    ungroup() # %>% sample_n(10000)

# Top (10) vs bottom (1) bin per endpoint: event counts and proportions side by
# side, plus their ratio; sorted by ascending ratio and shown with display
# names.
temp %>%
    filter(logh_perc %in% c(1, 10)) %>%
    pivot_wider(names_from = logh_perc, values_from = c(n, ratio)) %>%
    mutate(
        ratio = ratio_10 / ratio_1,
        endpoint = recode(endpoint, !!!endpoint_map)
    ) %>%
    # mutate(ratio_1 = ratio_1 * 100, ratio_10 = ratio_10 * 100) %>%
    arrange(ratio)

## Endpoint selection for health state and incident disease rates

In [None]:
# Same prediction/outcome join as before, this time stored as a data.table.
pred_outcomes <- as.data.table(
    left_join(predictions, data_outcomes, by = c("eid", "endpoint"))
)

## Example illustration with deciles

In [None]:
# Keep rows flagged `included == 1` and assign each row to one of 10 equally
# sized `logh` bins per endpoint (`logh_perc`, 1 = lowest values).
logh_inc <- pred_outcomes %>%
    filter(included == 1) %>%
    group_by(endpoint) %>%
    mutate(logh_perc = ntile(logh, 10)) %>%
    ungroup() %>%
    as_tibble()

In [None]:
# Event count (`n`) and event proportion (`ratio`) for every endpoint x decile.
logh_T_agg <- logh_inc %>%
    group_by(endpoint, logh_perc) %>%
    summarise(
        n = sum(event),
        ratio = mean(event)
    ) %>%
    as_tibble()

In [None]:
#logh_T_endpoint = logh_inc %>% group_by(endpoint) %>% summarise(n_all=sum(event), ratio_all = mean(event)) %>% as_tibble()

In [None]:
# Display the endpoint order used for factor levels / facets.
endpoint_order

In [None]:
# Display `temp` as last assigned (the selected-endpoint bin summary); it is
# not rebuilt from the `logh_T_agg` recomputed just above.
temp

In [None]:
# Notebook plot size and resolution for the single-panel example figure.
plot_width <- 2
plot_height <- 2
plot_res <- 320
options(
    repr.plot.width = plot_width,
    repr.plot.height = plot_height,
    repr.plot.res = plot_res
)
# temp_rank = event_rest %>% filter(features=="Metabolomics") %>% arrange(desc(MET10PercvsREST))
# endpoint_order = (endpoint_defs %>% mutate(phecode_rank = as.numeric(phecode)) %>% arrange(phecode_rank) %>% as_tibble())$endpoint

# Decile summary for the all-cause death endpoint only.
temp <- logh_T_agg %>%
    filter(endpoint == "OMOP_4306655") %>%
    ungroup() %>%
    arrange(endpoint) # %>% sample_n(10000)

# Event proportion (%) per risk decile; the bottom and top decile are
# highlighted in black and annotated with their counts and the rate ratio.
fig2a <- ggplot(
    temp,
    aes(
        x = as.numeric(as.character(logh_perc)),
        y = ratio * 100,
        color = logh_perc
    )
) +
    labs(title = "All-cause Death", x = "Risk Decile", y = "Incident Events (%)") +
    geom_line(alpha = 0.7, size = 0.3) +
    geom_point(alpha = 0.7, size = 0.3) +
    geom_point(
        data = temp %>% filter(logh_perc %in% c(1, 10)),
        alpha = 1, size = 1, color = "black"
    ) +
    # NOTE(review): hard-coded reference lines (about 3.87% and 0.28%); where
    # these values come from is not visible in this notebook — confirm they
    # still match the current data.
    geom_hline(aes(yintercept = 0.038674033 * 100), alpha = 0.3, linetype = "22", size = 0.25) +
    geom_hline(aes(yintercept = 0.002758621 * 100), alpha = 0.3, linetype = "22", size = 0.25) +
    coord_cartesian(ylim = c(0, 10)) +
    geom_text(
        data = temp %>% filter(logh_perc == 1),
        mapping = aes(label = glue("Bottom 10%: {n} ({round(ratio*100, 1)}%)")),
        color = "black", size = 2, x = 1, y = 9, hjust = 0
    ) +
    geom_text(
        data = temp %>% filter(logh_perc == 10),
        mapping = aes(label = glue("Top 10%: {n} ({round(ratio*100, 1)}%)")),
        color = "black", size = 2, x = 1, y = 10, hjust = 0
    ) +
    geom_text(
        data = temp %>% pivot_wider(names_from = logh_perc, values_from = c("n", "ratio")),
        mapping = aes(label = glue("Rate Ratio: ({round(ratio_10*100, 1)}% / {round(ratio_1*100, 1)}%)")),
        color = "black", size = 2, x = 1, y = 8, hjust = 0, alpha = 0.9
    ) +
    geom_text(
        data = temp %>% pivot_wider(names_from = logh_perc, values_from = c("n", "ratio")),
        mapping = aes(label = glue("Rate Ratio: ~ {round(ratio_10/ratio_1, 1)}")),
        color = "black", size = 3, x = 1, y = 7, hjust = 0, alpha = 0.9
    ) +
    scale_colour_gradient(
        low = "#7AC6FF", high = "#023768", space = "Lab",
        na.value = "grey50", guide = "colourbar", aesthetics = "colour"
    ) +
    scale_y_continuous(limits = c(0, NA), expand = expansion(mult = c(.02, .02))) + # , limits=c(0, NA))+
    scale_x_continuous(expand = expansion(add = c(0.25, 0.25)), breaks = 1:10) +
    theme(legend.position = "none", plot.title = element_text(hjust = 0.5))
fig2a

## Real Figure with percentiles

In [None]:
# Keep rows flagged `included == 1` and bin `logh` per endpoint.
# NOTE(review): this section and the x-axis label of the figure below speak of
# percentiles, but ntile(logh, 10) produces deciles — confirm whether 100 bins
# were intended or the labels should say "decile".
logh_inc <- pred_outcomes %>%
    filter(included == 1) %>%
    group_by(endpoint) %>%
    mutate(logh_perc = ntile(logh, 10)) %>%
    ungroup() %>%
    as_tibble()

In [None]:
# Event count (`n`) and event proportion (`ratio`) per endpoint and risk bin.
logh_T_agg <- logh_inc %>%
    group_by(endpoint, logh_perc) %>%
    summarise(
        n = sum(event),
        ratio = mean(event)
    ) %>%
    as_tibble()

In [None]:
#logh_T_endpoint = logh_inc %>% group_by(endpoint) %>% summarise(n_all=sum(event), ratio_all = mean(event)) %>% as_tibble()

In [None]:
# Display the endpoint order used for the facets of the figure below.
endpoint_order

In [None]:
# Notebook plot size and resolution for the faceted event-rate figure.
plot_width <- 4
plot_height <- 9
plot_res <- 320
options(
    repr.plot.width = plot_width,
    repr.plot.height = plot_height,
    repr.plot.res = plot_res
)
# temp_rank = event_rest %>% filter(features=="Metabolomics") %>% arrange(desc(MET10PercvsREST))
# endpoint_order = (endpoint_defs %>% mutate(phecode_rank = as.numeric(phecode)) %>% arrange(phecode_rank) %>% as_tibble())$endpoint

# Bin summary for the selected endpoints, ordered like `endpoint_order`.
temp <- logh_T_agg %>%
    filter(endpoint %in% endpoint_selection) %>%
    mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
    ungroup() %>%
    arrange(endpoint) # %>% sample_n(10000)

# Event proportion (%) against risk bin, one free-scaled facet per endpoint,
# labelled with the display names from `endpoint_map`.
# NOTE(review): the x-axis label says "Percentile" while `logh_perc` is built
# with ntile(logh, 10) in the cell above — confirm which is intended.
fig2a <- ggplot(temp, aes(x = logh_perc, y = ratio * 100, color = logh_perc)) +
    labs(title = NULL, x = "Risk Percentile (%)", y = "Incident Events (%)") +
    geom_point(alpha = 0.7, size = 0.3) +
    scale_colour_gradient(
        low = "#7AC6FF", high = "#023768", space = "Lab",
        na.value = "grey50", guide = "colourbar", aesthetics = "colour"
    ) +
    scale_y_continuous(limits = c(0, NA), expand = expansion(mult = c(0, .05))) + # , limits=c(0, NA))+
    scale_x_continuous(expand = expansion(add = c(0, 1))) +
    facet_wrap(
        ~endpoint,
        # Spelled out in full: the argument is `scales`; `scale =` only worked
        # through partial argument matching.
        scales = "free",
        labeller = labeller(
            endpoint = as_labeller(endpoint_map, default = label_wrap_gen(22))
        ),
        ncol = 3
    ) +
    theme(legend.position = "none")
fig2a

In [None]:
# Export the data behind the faceted event-rate figure.
# The output directory is created first so the write does not fail on a fresh
# checkout; dir.create() is a no-op when the directory already exists.
dir.create("outputs", showWarnings = FALSE, recursive = TRUE)
write_csv(temp, "outputs/SupplFigure2b.csv")

# Kaplan-Meier

In [None]:
# Keep rows flagged `included == 1` and, within each endpoint, rank `logh`
# into 100 equally sized bins (`logh_perc`, 1 = lowest values).
logh_inc <- pred_outcomes %>%
    filter(included == 1) %>%
    group_by(endpoint) %>%
    mutate(logh_perc = ntile(logh, 100)) %>%
    ungroup() %>%
    as_tibble()

In [None]:
# Assign each row to a risk group based on its percentile bin and drop
# everything that falls outside the three groups:
#   High = bins 91-100, Mid = bins 45-55, Low = bins 1-10.
# `MET` holds the same grouping as a factor with levels High, Mid, Low.
logh_ret <- logh_inc %>%
    select(endpoint, eid, logh_perc, event, time) %>%
    mutate(
        RET = case_when(
            logh_perc %in% 91:100 ~ "High",
            logh_perc %in% 45:55 ~ "Mid",
            logh_perc %in% 1:10 ~ "Low",
            TRUE ~ "NA" # sentinel string, removed by the filter below
        ),
        MET = fct_rev(factor(RET, levels = c("Low", "Mid", "High")))
    ) %>%
    filter(RET != "NA") # %>% select(eid, endpoint, logh, logh_group)

In [None]:
# Load the local helper script `ggkm.R` (path is relative to the working
# directory). NOTE(review): presumably it defines geom_km(), geom_kmticks()
# and geom_kmband() used in the figure below — confirm.
source("ggkm.R")

In [None]:
# Notebook plot size and resolution for the cumulative-event figure; the same
# width, height and resolution are reused when the figure is saved.
plot_width <- 8.25
plot_height <- 4.5
plot_res <- 320
options(
    repr.plot.width = plot_width,
    repr.plot.height = plot_height,
    repr.plot.res = plot_res
)

# Colour per risk group.
met_map <- c("High" = "#023768", "Mid" = "#4F8EC1", "Low" = "#7AC6FF")

# Risk-group rows for the selected endpoints, ordered like `endpoint_order`.
temp <- logh_ret %>%
    mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
    filter(endpoint %in% endpoint_selection) %>%
    ungroup()

# Cumulative event curves per risk group, one free-scaled facet per endpoint.
# The y axis is relabelled from proportion to percent.
fig2b <- ggplot(
    temp,
    aes(time = time, status = event, fill = RET, color = RET, group = RET)
) +
    geom_km(trans = "event") +
    geom_kmticks(trans = "event", size = 0.2, alpha = 0.01) +
    geom_kmband(trans = "event") +
    labs(x = "Time (Years)", y = "Cumulative Events (%)") +
    scale_color_manual(values = met_map) +
    scale_fill_manual(values = met_map) +
    scale_y_continuous(labels = function(x) round(x * 100, 1), expand = c(0, 0)) +
    scale_x_continuous(expand = expansion(add = c(0, .1)), breaks = c(5, 10)) +
    facet_wrap(
        ~endpoint,
        # Spelled out in full: the argument is `scales`; `scale =` only worked
        # through partial argument matching.
        scales = "free",
        labeller = labeller(
            endpoint = as_labeller(endpoint_map, default = label_wrap_gen(22))
        ),
        ncol = 6
    ) +
    theme(legend.position = "none")
fig2b

In [None]:
# Export the data behind the cumulative-event figure.
# The output directory is created first so the write does not fail on a fresh
# checkout; dir.create() is a no-op when the directory already exists.
# NOTE(review): `temp` is row-level data and still carries the `eid` column —
# confirm participant identifiers may be written to a supplementary file.
dir.create("outputs", showWarnings = FALSE, recursive = TRUE)
write_csv(temp, "outputs/SupplFigure2c.csv")

In [None]:
# Save the cumulative-event figure as PDF and PNG with the on-screen size and
# resolution.
# The plot is passed explicitly via `plot =`: ggsave()'s first parameter is
# `filename`, so piping the plot in only worked because `filename` was named.
# NOTE(review): the file name says "2bc" but only `fig2b` is written — confirm
# whether `fig2a` was meant to be combined into this output.
plot_name <- "SupplFigure2bc_EPIC_EventRates_Selected"
dir.create("outputs", showWarnings = FALSE, recursive = TRUE)
ggsave(
    filename = glue("outputs/{plot_name}.pdf"), plot = fig2b, device = "pdf",
    width = plot_width, height = plot_height, dpi = plot_res, limitsize = FALSE
)
ggsave(
    filename = glue("outputs/{plot_name}.png"), plot = fig2b, device = "png",
    width = plot_width, height = plot_height, dpi = plot_res, limitsize = FALSE
)