# Benchmarks

## Initialize

In [None]:
#library(Rmisc)
library(dtplyr)
library(tidyverse)
library(glue)
library(arrow)
library(patchwork)
library(data.table)
library("jsonlite")
library(ggthemes)

In [None]:
# Pick the storage root for this host: any node name containing "sc" uses the
# /sc-projects mount, everything else the cardioRS share.
# NOTE(review): "sc" is a loose substring test on the hostname -- confirm it
# cannot match unrelated machines.
base_path <- if (grepl("sc", Sys.info()[["nodename"]], fixed = TRUE)) {
  "/sc-projects/sc-proj-ukb-cvd"
} else {
  "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
}
print(base_path)

# Project-level locations derived from the storage root.
project_label <- "22_medical_records"
project_path <- glue("{base_path}/results/projects/{project_label}")
figure_path <- glue("{project_path}/figures")
output_path <- glue("{project_path}/data")

# Date-coded experiment id and its data folder.
experiment <- 220627
experiment_path <- glue("{output_path}/{experiment}")

In [None]:
# Athena OMOP vocabulary concept table (concept_id, concept_name, domain_id, ...).
# Built from base_path like every other path in this notebook instead of a
# hard-coded /sc-projects mount, so the host switch above also applies here.
# NOTE(review): assumes the non-sc share mirrors data/mapping/athena.
concept <- fread(glue("{base_path}/data/mapping/athena/CONCEPT.csv"))

## Generate Data

In [None]:
# Per-record frequency / event summary of this experiment, annotated with OMOP
# concept metadata. Records are stored as "OMOP_<concept_id>", so the prefix is
# stripped and the concept table's ids are cast to character to match.
record_freqs <- arrow::read_feather(glue("{experiment_path}/records_inc_disease_freq.feather")) %>%
  rename(concept_id = record) %>%
  mutate(concept_id = str_replace(concept_id, "OMOP_", "")) %>%
  left_join(
    concept %>% mutate(concept_id = as.character(concept_id)) %>% as_tibble(),
    # `on =` is not a left_join() argument (it lands in `...`); the key has to
    # be given with `by`.
    by = "concept_id"
  )

In [None]:
# Publication theme shared by all plots: classic base with small fonts, thin
# axis lines and ticks, bottom legend and major grid lines switched on.
base_size <- 8
title_size <- 10
facet_size <- 10
geom_text_size <- 3
theme_set(
  theme_classic(base_size = base_size) +
    theme(
      strip.background = element_blank(),
      strip.text.x = element_text(size = facet_size),
      plot.title = element_text(size = title_size, hjust = 0),
      axis.title = element_text(size = 10),
      axis.text = element_text(size = 8, color = "black"),
      axis.line = element_line(size = 0.2),
      axis.ticks = element_line(size = 0.2),
      legend.position = "bottom",
      panel.grid.major = element_line()
    )
)

In [None]:
# Flag a few well-known conditions (matched on concept_name) for highlighting,
# give only those a title-cased label, and derive the share of eligible
# individuals with the record plus its natural log.
record_freqs <- record_freqs %>%
  mutate(
    highlight = factor(case_when(
      str_detect(
        concept_name,
        "Heart failure$|Dilated cardiomyopathy|Type 2 diabetes mellitus$|Essential hypertension|^Rheumatoid arthritis$|Diabetic glomerulonephritis|Portal hypertension"
      ) ~ "Yes",
      TRUE ~ "No"
    )),
    label = case_when(highlight == "Yes" ~ tools::toTitleCase(concept_name), TRUE ~ ""),
    freq_records = n_records / n_eligable,
    log_freq_records = log(freq_records)
  )

In [None]:
# Only the highlighted concepts.
filter(record_freqs, highlight == "Yes")

In [None]:
# Distribution of the log record frequency.
ggplot(record_freqs) + geom_density(aes(x=log(freq_records)))

In [None]:
# Full table.
record_freqs

In [None]:
# Rare concepts (recorded in < 1% of eligible individuals) per domain.
n1 <- record_freqs %>% filter(freq_records < 0.01) %>% count(domain_id)

In [None]:
# Rare concepts that also have an event rate above 20%, per domain.
n2 <- record_freqs %>% filter(freq_records < 0.01, freq_events_record > 0.2) %>% count(domain_id)

In [None]:
# Share of rare concepts with a high event rate, per domain. Join on domain_id
# rather than bind_cols(): n2 can miss domains that n1 has, and positional
# binding would then pair counts from different domains (or fail outright).
n1 %>%
  left_join(n2, by = "domain_id", suffix = c("_rare", "_rare_high_event")) %>%
  mutate(
    n_rare_high_event = coalesce(n_rare_high_event, 0L),
    perc = n_rare_high_event / n_rare
  )

In [None]:
# Rare (< 1%) concepts with event rate > 20% whose name mentions portal hypertension.
record_freqs %>%
  filter(
    freq_records < 0.01,
    freq_events_record > 0.2,
    str_detect(concept_name, "Portal hypertension")
  )

In [None]:
# Concepts recorded in 3000-4000 individuals, highest event rate first.
record_freqs %>%
  filter(n_records > 3000, n_records < 4000) %>%
  arrange(desc(freq_events_record))

In [None]:
library(ggrepel)

In [None]:
library(scales)

In [None]:
# Supplementary Figure 1b: share of individuals with each record (log x axis)
# against its mortality rate, with the selected conditions labelled in red.
library(scales)
plot_width <- 8.25; plot_height <- 4; plot_res <- 320
# IRkernel takes the display resolution from `repr.plot.res`; `repr.plot.dpi`
# is not a repr option and was silently ignored.
options(repr.plot.width = plot_width, repr.plot.height = plot_height, repr.plot.res = plot_res)
rf_plot <- ggplot(
  record_freqs,
  # id/name/domain are not drawn by ggplot itself; presumably they feed the
  # plotly tooltips of ggplotly(rf_plot) below -- NOTE(review): confirm.
  aes(
    id = concept_id, name = concept_name, domain = domain_id,
    x = log_freq_records, y = freq_events_record,
    color = highlight, alpha = highlight, size = highlight, label = label
  )
) +
  labs(x = "Individuals with prior record (%)", y = "Mortality rate (%)") +
  scale_alpha_manual(values = c("Yes" = 1, "No" = 0.2)) +
  scale_color_manual(values = c("Yes" = "red", "No" = "black")) +
  scale_size_manual(values = c("Yes" = 2, "No" = 1)) +
  # y is a proportion; tick labels are written as percentages.
  scale_y_continuous(
    expand = c(0, 0),
    breaks = c(0.25, 0.5, 0.75, 1),
    labels = c("25", "50", "75", "100")
  ) +
  # x is the natural log of a proportion; ticks at 0.01%, 0.1%, 1% and 10%.
  scale_x_continuous(
    expand = c(0, 0),
    breaks = log(c(1/10000, 1/1000, 1/100, 1/10)),
    labels = c("0.01", "0.1", "1", "10")
  ) +
  geom_label_repel(box.padding = 0.8, max.overlaps = Inf, size = 3, color = "black", force = 3, min.segment.length = 0) +
  geom_point() +
  theme(legend.position = "none") +
  coord_cartesian(ylim = c(0, 0.75))
rf_plot

In [None]:
library(gt)
# Export Supplementary Figure 1b as a Cairo PDF into outputs/.
plot_name <- "SupplFigure1b_recordsmortality"
ggsave(
  filename = glue("outputs/{plot_name}.pdf"),
  plot = rf_plot,
  device = cairo_pdf,
  width = plot_width, height = plot_height, dpi = plot_res
)

## How many individuals have rare records?

In [None]:
# Baseline record table: one row per individual (eid) plus one column per
# concept -- presumably 0/1 indicators (NOTE(review): confirm encoding).
# `as_data_frame = FALSE` used to sit inside glue(), where it was just an unused
# template variable, so the file was always read as a data frame. The cells
# below rely on that (pivot_longer, rowSums, `[<-`), so it is requested explicitly.
data_records <- arrow::read_feather(
  glue("{output_path}/baseline_records_220627.feather"),
  as_data_frame = TRUE
)

In [None]:
# Per-concept frequency (column mean over individuals), least frequent first.
data_records_freq <- data_records %>%
  summarise(across(-eid, mean)) %>%
  pivot_longer(everything(), names_to = "concept_id", values_to = "freq") %>%
  arrange(freq)

## Simpler binning: common (≥ 1%) vs. rare (< 1%) concepts

In [None]:
# Bin each concept by frequency: "Common" (>= 1%) or "Rare" (< 1%); NA
# frequencies stay NA. Then attach OMOP metadata, keyed on the "OMOP_<id>"
# column names used in data_records.
data_records_freq_binned <- data_records_freq %>%
  mutate(freq_bin = case_when(
    freq >= 0.01 ~ "Common",
    freq < 0.01 ~ "Rare"
  )) %>%
  as_tibble() %>%
  left_join(
    concept %>% as_tibble() %>% mutate(concept_id = paste0("OMOP_", concept_id)),
    # Explicit key (concept_id is the only shared column).
    by = "concept_id"
  )

In [None]:
# Number of distinct concepts of one frequency bin that each individual has.
# Prints how many OMOP concepts fall into the bin and returns one count per
# row of `records`, in row order.
count_records_in_bin <- function(records, binned_freqs, bin) {
  bin_concepts <- binned_freqs %>%
    filter(str_detect(concept_id, "OMOP_"), freq_bin == !!bin) %>%
    pull(concept_id)
  print(length(bin_concepts))
  unname(rowSums(select(records, all_of(bin_concepts))))
}
n_common_counts <- count_records_in_bin(data_records, data_records_freq_binned, "Common")

In [None]:
n_rare_counts <- count_records_in_bin(data_records, data_records_freq_binned, "Rare")

In [None]:
# One row per individual with their number of common and rare concepts.
overview <- data_records %>% select(eid)
overview["n_common"] <- n_common_counts
overview["n_rare"] <- n_rare_counts

In [None]:
# Rare-to-common concept ratio per individual. It is Inf when an individual
# has rare but no common concepts, and NaN when they have neither.
temp_plot <- overview %>% mutate(ratio = n_rare / n_common)

In [None]:
library(ggdist)
# Median ratio with interval (ggdist defaults); NaN ratios are dropped by na.rm.
temp_plot %>% median_qi(ratio, na.rm = TRUE)

In [None]:
# Density of the log ratio. log(ratio) is -Inf/Inf/NaN for individuals with no
# rare or no common concepts. ggplot would drop those rows with a warning, so
# they are excluded explicitly here.
options(repr.plot.width = 3, repr.plot.height = 2, repr.plot.res = 600)
ratio <- ggplot(temp_plot %>% filter(is.finite(log(ratio))), aes(x = log(ratio))) +
  labs(x = "Rare versus common concepts (Ratio)", y = "Density") +
  geom_density(fill = "black", alpha = 0.1) +
  scale_x_continuous(
    expand = c(0, 0),
    breaks = c(log(0.1), log(1), log(10)),
    labels = c("0.1", "1", "10")
  ) +
  scale_y_continuous(expand = c(0, 0))
ratio

In [None]:
# Long format: one row per individual and concept type (n_common / n_rare).
temp <- pivot_longer(overview, -eid, names_to = "type", values_to = "count")

In [None]:
temp

In [None]:
# Share of individuals with more than 0, 10 and 100 concepts of each type.
temp %>%
  group_by(type) %>%
  summarise(
    n_bigger0_ratio = mean(count > 0),
    n_bigger10_ratio = mean(count > 10),
    n_bigger100_ratio = mean(count > 100)
  ) %>%
  mutate(type = factor(type, levels = c("n_common", "n_rare"))) %>%
  arrange(type)

In [None]:
# For each concept type: share of individuals with at least n_unique_records
# concepts. rev_cumsum at count k is the number of individuals with more than
# k concepts, i.e. at least k + 1. The cohort size is the number of individuals
# per type, taken from the data instead of the hard-coded 502460.
temp_plot <- temp %>%
  count(type, count, name = "n_count") %>%
  group_by(type) %>%
  mutate(
    n_individuals = sum(n_count),
    rev_cumsum = n_individuals - cumsum(n_count),
    freq_rev_cumsum = rev_cumsum / n_individuals,
    n_unique_records = count + 1
  ) %>%
  ungroup() %>%
  select(type, n_unique_records, rev_cumsum, freq_rev_cumsum) %>%
  mutate(
    log_n_unique_records = log(n_unique_records),
    type = fct_rev(factor(type, levels = c("n_common", "n_rare")))
  )

In [None]:
# Facet labels per concept type. NOTE(review): the concept counts are
# hard-coded ("." as thousands separator) and should match the bin sizes
# printed when the concepts were binned.
type_map <- setNames(
  c("Common (>= 1%, n = 1.186)", "Rare (< 1%, n = 69.850)"),
  c("n_common", "n_rare")
)

In [None]:
options(repr.plot.width = 8.25, repr.plot.height = 2, repr.plot.res = 600)

# Arrow plus percentage label for the share of individuals with at least `at`
# unique concepts: the arrow runs from `at` to `arrow_to`, and the label starts
# at `label_at` (all in concept counts, drawn on the log x axis).
threshold_annotation <- function(at, arrow_to, label_at) {
  at_rows <- temp_plot %>% filter(n_unique_records == at)
  list(
    geom_segment(
      data = at_rows,
      mapping = aes(x = !!log(at), xend = !!log(arrow_to), y = freq_rev_cumsum*100, yend = freq_rev_cumsum*100),
      alpha = 0.5, arrow = arrow(length = unit(0.1, "cm"), type = "closed"), size = 0.25
    ),
    geom_text(
      data = at_rows,
      mapping = aes(label = glue("{round(freq_rev_cumsum*100, 1)}%"), y = freq_rev_cumsum*100),
      size = 2.5, x = log(label_at), hjust = 0
    )
  )
}

# Share of individuals (%) with at least a given number of unique concepts,
# one facet per concept type, annotated at 1 and 10 concepts.
minimum <- ggplot(temp_plot, aes(x = log_n_unique_records, y = freq_rev_cumsum*100)) +
  labs(x = "Unique concepts before recruitment (n)", y = "Individuals (%)") +
  geom_ribbon(aes(xmin = log_n_unique_records, ymin = 0, ymax = freq_rev_cumsum*100), alpha = 0.1) +
  geom_line(color = "black") +
  threshold_annotation(at = 1, arrow_to = 3, label_at = 4) +
  threshold_annotation(at = 10, arrow_to = 20, label_at = 25) +
  scale_x_continuous(expand = c(0, 0), breaks = c(log(1), log(10), log(100)), labels = c("1", "10", "100")) +
  scale_y_continuous(expand = c(0, 0)) +
  coord_cartesian(xlim = c(log(1), log(200)), ylim = c(0, 100)) +
  facet_grid(~type, labeller = labeller(type = type_map)) +
  theme(legend.position = "none")

In [None]:
# Supplementary Figure 1b/c: record-frequency plot on top; concept-count curves
# and the rare/common ratio density below (80/20 height split).
plot_width <- 8.25; plot_height <- 6; plot_res <- 320
# `repr.plot.res` is the IRkernel resolution option (`repr.plot.dpi` is ignored).
options(repr.plot.width = plot_width, repr.plot.height = plot_height, repr.plot.res = plot_res)
# patchwork's argument is `heights`; `height` depended on partial argument matching.
supplfig1bc <- rf_plot / (minimum + ratio) + plot_layout(heights = c(0.8, 0.20))
supplfig1bc

In [None]:
library(gt)
# Export the combined Supplementary Figure 1b/c as a Cairo PDF into outputs/.
plot_name <- "SupplFigure1bc_recordsmortalityratio"
ggsave(
  filename = glue("outputs/{plot_name}.pdf"),
  plot = supplfig1bc,
  device = cairo_pdf,
  width = plot_width, height = plot_height, dpi = plot_res
)

In [None]:
# Concepts recorded in at least one individual.
filter(data_records_freq_binned, freq > 0)

In [None]:
library(plotly)
# Interactive version of the record-frequency plot.
rf_plot_plotly <- ggplotly(rf_plot)

In [None]:
# Write the interactive plot to an HTML file in the working directory.
htmlwidgets::saveWidget(rf_plot_plotly, "record_frequencies.html")

### Funnel plot

In [None]:
# funnelR input: d = individuals with the record, n = d * event rate (events).
# NOTE(review): needs `n` and `event` columns, which exist in the record_freqs
# rebuilt from `dfs` further below, not in the version loaded at the top
# (n_records / freq_events_record).
temp_funnel <- record_freqs %>%
  mutate(d = n, n = d * event) %>%
  select(concept_id, concept_name, domain_id, n, d)
temp_funnel

In [None]:
# Overall event rate in data_prep (presumably the source of the hard-coded
# funnel benchmark below).
summarise(data_prep, rate = mean(event))

In [None]:
# funnelR is otherwise only attached further down the notebook; load it here
# so this cell runs in top-to-bottom order.
library(funnelR)

# Funnel centre line: overall event rate (hard-coded, presumably the `rate`
# printed by the data_prep summary above).
funnel_benchmark <- 0.0781216415237034

# Score each concept against exact control limits at alpha = 0.80 / 0.95.
funnel_scores <- funscore(
  input = temp_funnel,
  benchmark = funnel_benchmark,
  alpha = 0.80,
  alpha2 = 0.95,
  method = "exact"
)

# Control-limit curves for plotting (step = 1 along the denominator).
funnel_limits <- fundata(
  input = temp_funnel,
  benchmark = funnel_benchmark,
  alpha = 0.80,
  alpha2 = 0.95,
  method = "exact",
  step = 1
)

In [None]:
head(funnel_scores)

In [None]:
# Concepts with fewer than 5000 events (n).
filter(funnel_scores, n < 5000)

In [None]:
# How many of those fall into each score2 class.
funnel_scores %>%
  filter(n < 5000) %>%
  count(score2)

In [None]:
# Funnel plot: mortality rate versus number of individuals with the record,
# with the alpha2 (0.95) control limits in orange around the red benchmark.
# Stored as `funnel_plot` so it no longer overwrites `rf_plot`, which
# Supplementary Figure 1b is built from.
options(repr.plot.width = 10, repr.plot.height = 6, repr.plot.res = 600)
funnel_plot <- ggplot(funnel_scores, aes(x = d, y = r)) +
  labs(x = "Record Frequency", y = "Mortality Rate [%]") +
  # One reference line. Mapping `benchmark` over funnel_limits drew an
  # identical line for every row of the limits table.
  geom_hline(yintercept = funnel_limits$benchmark[1], colour = "red") +
  geom_line(data = funnel_limits, aes(x = d, y = up2), color = "orange") +
  geom_line(data = funnel_limits, aes(x = d, y = lo2), color = "orange") +
  scale_x_continuous(trans = "log10", expand = c(0, 0)) +
  # r = n / d is a proportion (n = d * event rate); show it in percent to
  # match the axis title.
  scale_y_continuous(labels = function(x) x * 100) +
  coord_cartesian(xlim = c(25, NA)) +
  geom_point(size = 0.2, aes(alpha = score2, color = score2)) +
  theme(legend.position = "none") +
  scale_color_manual(values = c("Extreme" = "black", "In Control" = "black")) +
  scale_alpha_manual(values = c("Extreme" = 0.5, "In Control" = 0.1))
funnel_plot

In [None]:
library(funnelR)



In [None]:
# NOTE(review): adapted from the funnelR vignette. It needs `my_plot`, which is
# only created in the next cell, and maps `label = id`, which temp_funnel as
# built above does not appear to have -- confirm before relying on this cell.
my_plot4_mod <- my_plot +
  labs(x = "Physician practice size", y = "Proportion (%) of satisfied patients") +
  geom_hline(yintercept = 0.40, colour = "darkred", linetype = 6, size = 1) +
  theme_minimal() +
  scale_colour_manual(values = c("green", "darkgreen")) +
  geom_text(aes(label = id), colour = "black", size = 4, nudge_x = 10)

my_plot4_mod

In [None]:
library(funnelR)

# Control limits with the same settings as funnel_limits, then funnelR's
# default funnel plot.
my_limits <- fundata(
  input = temp_funnel,
  benchmark = 0.0781216415237034,
  alpha = 0.80,
  alpha2 = 0.95,
  method = "exact",
  step = 1
)
my_plot <- funplot(input = temp_funnel, fundata = my_limits)

my_plot

# Hazard Ratios

In [None]:
head(data_prep)

In [None]:
library(foreach)
library(doParallel)
# NOTE(review): registers 20 cores, but the Cox loop below is a plain
# sequential `for`, so this backend is not used.
registerDoParallel(cores = 20)

In [None]:
# Example record for trying out the Cox model.
r <- "OMOP_4081598"

In [None]:
# Covariates, the example record column and the survival outcome (event, time).
data_temp <- arrow::read_feather(
  "/sc-projects/sc-proj-ukb-cvd/data/2_datasets_pre/211110_anewbeginning/data_prep_recordfrequencies_220412.feather",
  col_select = c(eid, age_at_recruitment_f21022_0_0, sex_f31_0_0, !!r, event, time)
)

In [None]:
str(data_temp)

In [None]:
library("survival")
# Fit an age- and sex-adjusted Cox model for a single record.
#
# r:    name of the record column, e.g. "OMOP_4081598".
# path: feather file with eid, age/sex covariates, the record columns, event
#       and time. Defaults to the file used so far, so existing calls work
#       unchanged.
# Returns the coxph fit. Only the columns the model needs are read.
fit_cox <- function(r, path = "/sc-projects/sc-proj-ukb-cvd/data/2_datasets_pre/211110_anewbeginning/data_prep_recordfrequencies_220412.feather") {
  data_temp <- arrow::read_feather(
    path,
    col_select = c(eid, age_at_recruitment_f21022_0_0, sex_f31_0_0, all_of(r), event, time)
  )
  f <- as.formula(glue("Surv(time, event) ~ age_at_recruitment_f21022_0_0+sex_f31_0_0+{r}"))
  cox <- coxph(f, data = data_temp)
  # The formula's environment is this call frame. Drop the data so the returned
  # fit does not keep the whole table alive through it.
  rm(data_temp)
  cox
}

In [None]:
library(broom)

# One Cox model per record. Keep its model-level summary (glance) and
# coefficient table (tidy), keyed by record id. `cox` holds the last fit
# afterwards.
coxsummaries <- list()
coxcoefs <- list()
record_list <- record_ids
for (r in record_list) {
  cox <- fit_cox(r)
  coxsummaries[[r]] <- mutate(glance(cox), record = r)
  coxcoefs[[r]] <- mutate(tidy(cox), record = r)
}

In [None]:
# Stack the per-record model summaries (record id first) and save them.
coxsummaries <- coxsummaries %>% bind_rows() %>% relocate(record)
write_feather(coxsummaries, "rf_coxsummaries_220413.feather")

In [None]:
# Stack the per-record coefficient tables (record id first) and save them.
coxcoefs <- coxcoefs %>% bind_rows() %>% relocate(record)
write_feather(coxcoefs, "rf_coxcoefs_220413.feather")

In [None]:
# Keep only the record coefficients (drop the age and sex terms).
temp_coxcoefs <- filter(coxcoefs, str_detect(term, "OMOP"))

In [None]:
# Hazard ratios plus record frequencies.
# NOTE(review): `record_frequencies` is not created anywhere above, and the
# join has no `by` -- confirm the table and its key.
temp_hr <- temp_coxcoefs %>%
  mutate(HR = exp(estimate)) %>%
  left_join(record_frequencies)

In [None]:
temp_hr

In [None]:
# Adjusted hazard ratio of each record versus its frequency (log x axis) for
# condition, drug and procedure concepts.
# 502460 is a hard-coded cohort size (presumably all individuals) -- NOTE(review).
options(repr.plot.width = 10, repr.plot.height = 6, repr.plot.res = 600)
hr_plot <- ggplot(
  temp_hr %>% filter(domain_id %in% c("Condition", "Drug", "Procedure")),
  aes(id = concept_id, name = concept_name, domain = domain_id, x = n / 502460, y = HR)
) +
  # A hazard ratio is a ratio, not a percentage.
  labs(x = "Record Frequency", y = "Adj. HR") +
  scale_x_log10(expand = c(0, 0)) +
  geom_point(size = 0.2) +
  theme(legend.position = "none")
hr_plot

In [None]:
# Same hazard-ratio scatter, semi-transparent, with a dashed red reference at HR = 1.
options(repr.plot.width = 10, repr.plot.height = 6, repr.plot.res = 600)
hr_plot <- temp_hr %>%
  filter(domain_id %in% c("Condition", "Drug", "Procedure")) %>%
  ggplot(aes(id = concept_id, name = concept_name, domain = domain_id, x = n/502460, y = HR)) +
  labs(x = "Record Frequency", y = "Adj. Hazard Ratio for Record") +
  scale_x_log10(expand = c(0, 0)) +
  geom_point(size = 0.2, alpha = 0.1) +
  geom_hline(yintercept = 1, linetype = "22", color = "red") +
  theme(legend.position = "none")
hr_plot

In [None]:
# Scratch: 100 log-normal draws (exp of standard normal).
rnorm(100, mean = 0, sd = 1) %>% exp()

In [None]:
# Records in the main domains, highest hazard ratio first.
temp_hr %>%
  select(record, estimate, HR, n, concept_name, domain_id) %>%
  filter(domain_id %in% c("Condition", "Drug", "Procedure")) %>%
  arrange(desc(HR))

In [None]:
temp_hr

In [None]:
library("metafor")

In [None]:
# metafor funnel plot of the record log-HRs (estimate) against their standard errors.
funnel(x = temp_hr$estimate, sei = temp_hr[["std.error"]], ni = temp_hr$n, yaxis = "sei", size = 0.1)

In [None]:
# funnelR input with n = HR and d = the record count. `d` has to be taken
# before `n` is overwritten: mutate() evaluates left to right, so `d = n`
# placed after `n = HR` made d a copy of the hazard ratio.
temp_funnel <- temp_hr %>% mutate(d = n, n = HR)
temp_funnel

In [None]:
library(funnelR)

# Control limits around a benchmark of 1 (HR = 1).
# NOTE(review): funnelR builds these limits for proportions n/d; confirm they
# are meaningful for hazard-ratio input.
my_limits <- fundata(
  input = temp_funnel,
  benchmark = 1,
  alpha = 0.80,
  alpha2 = 0.95,
  method = "exact",
  step = 1
)

my_limits

In [None]:
# Points from temp_funnel with the control limits computed above. The limits
# table is `my_limits`; passing temp_funnel as `fundata` plotted the input
# table as if it were the limits.
my_plot <- funplot(input = temp_funnel, fundata = my_limits)

my_plot

In [None]:
# NOTE(review): `record_frequency` is not created anywhere above.
record_frequency

In [None]:
library(broom)

In [None]:
# Model-level summary of the last Cox fit from the loop above.
cox %>% glance()

In [None]:
# Coefficient table of the last Cox fit from the loop above.
cox %>% tidy()

In [None]:
# NOTE(review): `res.cox` is not created anywhere above.
res.cox

In [None]:
select(data_prep, eid, age_at_recruitment_f21022_0_0, sex_f31_0_0, OMOP_1000560, event)

In [None]:
# Condition concepts seen in more than 20 individuals with an event rate above
# 50%, highest first.
record_freqs %>%
  filter(n > 20, event > 0.5, domain_id == "Condition") %>%
  arrange(desc(event))

In [None]:
# Outcome rows for phecode_404 (presumably myocardial infarction, hence "mi").
mi_data <- data_outcomes %>% filter(endpoint == "phecode_404")
# Attach that outcome to every individual's records by eid, keep individuals
# with prevalent == 0, and drop outcome columns not needed downstream
# (event is kept).
data_records_mi <- data_records %>%
  left_join(mi_data, by = "eid") %>%
  filter(prevalent == 0) %>%
  select(-endpoint, -prevalent, -time)

In [None]:
data_records_mi

In [None]:
# Per-record count of individuals (n) and their event rate, processed in
# column subsets (presumably to limit the size of the long table).
# `record_ids_subsets` presumably holds vectors of OMOP_* column names.
dfs <- vector("list", length(record_ids_subsets))
for (i in seq_along(record_ids_subsets)) {
  print(i)
  # Named `record_subset` so the loop does not mask base::subset().
  record_subset <- record_ids_subsets[[i]]
  dfs[[i]] <- data_records_mi %>%
    select(eid, all_of(record_subset), event) %>%
    pivot_longer(starts_with("OMOP_"), names_to = "concept_id", values_to = "record") %>%
    filter(record == 1) %>%
    group_by(concept_id) %>%
    summarise(n = n(), event = mean(event))
  flush.console()
}

In [None]:
# Stack the per-subset results and attach OMOP concept metadata. The "OMOP_"
# prefix is removed, and the concept table's ids are cast to character to match.
record_freqs <- bind_rows(dfs) %>%
  mutate(concept_id = str_remove_all(concept_id, "OMOP_")) %>%
  # `on =` is not a left_join() argument; the key goes in `by`.
  left_join(concept %>% mutate(concept_id = as.character(concept_id)), by = "concept_id")

In [None]:
# Re-apply the shared publication theme: classic base, small fonts, thin axis
# lines and ticks, bottom legend, major grid lines on.
base_size <- 8
title_size <- 10
facet_size <- 10
geom_text_size <- 3
theme_set(
  theme_classic(base_size = base_size) +
    theme(
      strip.background = element_blank(),
      strip.text.x = element_text(size = facet_size),
      plot.title = element_text(size = title_size, hjust = 0),
      axis.title = element_text(size = 10),
      axis.text = element_text(size = 8, color = "black"),
      axis.line = element_line(size = 0.2),
      axis.ticks = element_line(size = 0.2),
      legend.position = "bottom",
      panel.grid.major = element_line()
    )
)

In [None]:
# Highlight a few common risk-factor concepts, and express n as a fraction of
# all rows in data_records.
record_freqs <- record_freqs %>%
  mutate(
    highlight = factor(case_when(
      str_detect(concept_name, "simvastatin|Type 2 diabetes mellitus|smok|aspirin|Essential hypertension") ~ "Yes",
      TRUE ~ "No"
    )),
    freq = n / nrow(data_records)
  )

In [None]:
filter(record_freqs, highlight == "Yes")

In [None]:
# Record frequency versus event rate, highlighted concepts in red.
options(repr.plot.width = 10, repr.plot.height = 5, repr.plot.res = 320)
ggplot(record_freqs, aes(x = freq, y = event, color = highlight, alpha = highlight, size = highlight)) +
  scale_alpha_manual(values = c("Yes" = 1, "No" = 0.1)) +
  scale_color_manual(values = c("Yes" = "red", "No" = "black")) +
  scale_size_manual(values = c("Yes" = 1, "No" = 0.1)) +
  geom_point() +
  scale_x_continuous(expand = c(0, 0))

In [None]:
# Acute myocardial infarction

In [None]:
record_freqs

In [None]:
# NOTE(review): not created anywhere above.
data_records_death_long_subset

In [None]:
library(svMisc)

In [None]:
# First record id, used as an example endpoint label.
endpoint_label <- head(record_ids, 1)

In [None]:
endpoint_label

In [None]:
# For every record: number of individuals with it (n) and their death rate
# (mean of event). One row is built per record and they are bound once at the
# end. add_row() returns a new tibble rather than modifying in place, so the
# previous loop never kept its rows.
record_baselines <- bind_rows(lapply(seq_along(record_ids), function(i) {
  progress(i, max.value = length(record_ids))
  endpoint_label <- record_ids[i]
  temp_record <- data_records_death %>% filter(!!sym(endpoint_label) == 1)
  flush.console()
  tibble(
    concept_id = endpoint_label,
    n = nrow(temp_record),
    death = mean(temp_record$event)
  )
}))

In [None]:
record_baselines

In [None]:
# Death rate among individuals with record OMOP_1000772. The event column has
# to be pulled from the filtered table; `& event` was evaluated outside the
# data frame and never referred to its `event` column.
mean((data_records_death %>% filter(OMOP_1000772 == 1))$event)

In [None]:
temp_record <- filter(data_records_death, OMOP_1000772 == 1)

In [None]:
n_record <- nrow(temp_record)

In [None]:
# Number of cells in the record table.
nrow(data_records_death) * ncol(data_records_death)

In [None]:
select(data_records_death, starts_with("OMOP_"))

In [None]:
# NOTE(review): unfinished plot stub. The closing parenthesis was missing (a
# syntax error); with an empty aes() this only draws an empty panel.
ggplot(temp, aes())

In [None]:
# Model whose predictions are analysed below.
model <- "GNN(Records)+MLP"

In [None]:
# Read the prediction file of one partition, without the split column and
# restricted to rows of `model`.
read_partition_predictions <- function(partition, model) {
  arrow::read_feather(
    glue("{output_path}/predictions/predictions_partition{partition}_220223.feather"),
    as_data_frame = TRUE
  ) %>%
    select(-split) %>%
    filter(model == !!model)
}
# NOTE(review): prediction files exist for partitions 0-21 and all 22 used to
# be read, but only 0-14 were ever combined. Reading is now limited to the
# combined ones -- confirm whether partitions 15-21 should be included.
partitions_used <- 0:14

In [None]:
# All used partitions in one lazy data.table for the joins/aggregations below.
predictions <- lazy_dt(bind_rows(lapply(partitions_used, read_partition_predictions, model = model)))

In [None]:
# Publication theme (this variant has no major grid lines): classic base,
# small fonts, thin axis lines and ticks, bottom legend.
base_size <- 8
title_size <- 10
facet_size <- 10
geom_text_size <- 3
theme_set(
  theme_classic(base_size = base_size) +
    theme(
      strip.background = element_blank(),
      strip.text.x = element_text(size = facet_size),
      plot.title = element_text(size = title_size, hjust = 0),
      axis.title = element_text(size = 10),
      axis.text = element_text(size = 8, color = "black"),
      axis.line = element_line(size = 0.2),
      axis.ticks = element_line(size = 0.2),
      legend.position = "bottom"
    )
)

In [None]:
# Model colour palette from colors.json.
# NOTE(review): every model currently gets the same colour (pastel red, mid).
colors_dict <- read_json("colors.json")
color_map <- c(
  "Identity(AgeSex)+MLP"         = colors_dict$pastel$red$mid,
  "Identity(Records)+MLP"        = colors_dict$pastel$red$mid,
  "GNN(Records)+MLP"             = colors_dict$pastel$red$mid,
  "Identity(AgeSex+Records)+MLP" = colors_dict$pastel$red$mid,
  "GNN(AgeSex+Records)+MLP"      = colors_dict$pastel$red$mid
)

In [None]:
# Event share per endpoint among individuals with prevalent == 0.
outcome_freq <- data_outcomes %>%
  filter(prevalent == 0) %>%
  group_by(endpoint) %>%
  summarize(freq = sum(event) / n()) %>%
  as_tibble()
arrange(outcome_freq, desc(freq))

In [None]:
# Lookup endpoint id -> phecode description (used as facet labels).
endpoint_map <- setNames(phecode_defs$phecode_string, phecode_defs$endpoint)
# Endpoints from most to least frequent.
endpoint_order_freq <- outcome_freq %>% arrange(desc(freq)) %>% pull(endpoint)

## Load data

In [None]:
# Every directory below the project folder.
list.dirs(project_path, full.names = TRUE, recursive = TRUE)

# Figure 2: Selected Endpoints

## Medical history state and incident disease

In [None]:
# Predictions joined to their outcome by individual and endpoint, then
# materialised as a data.table. `on =` is not a join argument, so the key is
# given with `by`.
# NOTE(review): assumes eid and endpoint are the intended keys; if the
# predictions also carry event/time columns, check for .x/.y suffixes.
pred_outcomes <- predictions %>%
  left_join(data_outcomes, by = c("eid", "endpoint")) %>%
  as.data.table()

In [None]:
# Rows with prevalent == 0; logh ranked into percentiles 1-100 within each
# endpoint and model.
logh_inc <- pred_outcomes %>%
  filter(prevalent == 0) %>%
  group_by(endpoint, model) %>%
  mutate(logh_perc = ntile(logh, 100)) %>%
  ungroup() %>%
  as_tibble()

## No buffer (all follow-up events counted)

In [None]:
# Observed event rate per endpoint, model and logh percentile.
logh_T_agg <- logh_inc %>%
  group_by(endpoint, model, logh_perc) %>%
  summarise(ratio = mean(event)) %>%
  as_tibble()

In [None]:
write_feather(logh_T_agg, glue("{output_path}/logh_agg_220224.feather"))

In [None]:
plot_width <- 50; plot_height <- 75; plot_res <- 320
options(repr.plot.width = plot_width, repr.plot.height = plot_height, repr.plot.res = plot_res)
# Endpoints ordered by numeric phecode. The selection keeps all of them; narrow
# it (e.g. with head()) for quicker previews.
endpoint_order <- (phecode_defs %>%
  mutate(phecode_rank = as.numeric(phecode)) %>%
  arrange(phecode_rank) %>%
  as_tibble())$endpoint
endpoint_selection <- endpoint_order
temp <- logh_T_agg %>%
  filter(model == "GNN(Records)+MLP") %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
  filter(endpoint %in% endpoint_selection) %>%
  ungroup()

# Observed event rate per medical-history (logh) percentile, one free-scaled
# facet per endpoint labelled with its phecode description.
mh_events <- ggplot(temp, aes(x = logh_perc, y = ratio*100, color = logh_perc)) +
  labs(title = NULL, x = "Medical History Percentile [%]", y = "Observed Event Rate [%]") +
  geom_point(alpha = 0.7, size = 0.1) +
  scale_colour_gradient(
    low = "#7AC6FF",
    high = "#023768",
    space = "Lab",
    na.value = "grey50",
    guide = "colourbar",
    aesthetics = "colour"
  ) +
  scale_y_continuous(limits = c(0, NA), expand = expansion(mult = c(0, .05))) +
  scale_x_continuous(expand = expansion(add = c(0, 1))) +
  # `scales` is the facet_wrap() argument; `scale` relied on partial matching.
  facet_wrap(~endpoint, scales = "free", labeller = labeller(endpoint = endpoint_map), ncol = 25) +
  theme(legend.position = "none")

In [None]:
# Export the percentile/event-rate grid as PDF.
plot_name <- "MedicalHistoryRisk"
ggsave(
  filename = glue("outputs/{plot_name}.pdf"), plot = mh_events, device = "pdf",
  width = plot_width, height = plot_height, dpi = plot_res, limitsize = FALSE
)

In [None]:
# ...and as PNG.
plot_name <- "MedicalHistoryRisk"
ggsave(
  filename = glue("outputs/{plot_name}.png"), plot = mh_events, device = "png",
  width = plot_width, height = plot_height, dpi = plot_res, limitsize = FALSE
)

## 1-year buffer (events with time < 1 year are not counted)

In [None]:
head(pred_outcomes)

In [None]:
# Same aggregation as logh_T_agg, but events with time < 1 are not counted as
# events (1-year buffer).
logh_T_agg_buffer <- pred_outcomes %>%
  filter(prevalent == 0) %>%
  mutate(event_buffer = case_when(event != 0 & time < 1 ~ 0, TRUE ~ event)) %>%
  group_by(endpoint, model) %>%
  mutate(logh_perc = ntile(logh, 100)) %>%
  group_by(endpoint, model, logh_perc) %>%
  summarise(ratio = mean(event_buffer)) %>%
  as_tibble()

In [None]:
write_feather(logh_T_agg_buffer, glue("{output_path}/logh_agg_1ybuffer_220224.feather"))

In [None]:
plot_width <- 50; plot_height <- 75; plot_res <- 320
options(repr.plot.width = plot_width, repr.plot.height = plot_height, repr.plot.res = plot_res)
# Endpoints ordered by numeric phecode. The selection keeps all of them; narrow
# it (e.g. with head()) for quicker previews.
endpoint_order <- (phecode_defs %>%
  mutate(phecode_rank = as.numeric(phecode)) %>%
  arrange(phecode_rank) %>%
  as_tibble())$endpoint
endpoint_selection <- endpoint_order
temp <- logh_T_agg_buffer %>%
  filter(model == "GNN(Records)+MLP") %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
  filter(endpoint %in% endpoint_selection) %>%
  ungroup()

# Observed event rate (1-year buffer) per medical-history percentile, one
# free-scaled facet per endpoint.
mh_events <- ggplot(temp, aes(x = logh_perc, y = ratio*100, color = logh_perc)) +
  labs(title = NULL, x = "Medical History Percentile [%]", y = "Observed Event Rate [%]") +
  geom_point(alpha = 0.7, size = 0.1) +
  scale_colour_gradient(
    low = "#7AC6FF",
    high = "#023768",
    space = "Lab",
    na.value = "grey50",
    guide = "colourbar",
    aesthetics = "colour"
  ) +
  scale_y_continuous(limits = c(0, NA), expand = expansion(mult = c(0, .05))) +
  scale_x_continuous(expand = expansion(add = c(0, 1))) +
  # `scales` is the facet_wrap() argument; `scale` relied on partial matching.
  facet_wrap(~endpoint, scales = "free", labeller = labeller(endpoint = endpoint_map), ncol = 25) +
  theme(legend.position = "none")

In [None]:
# Export the 1-year-buffer grid as PNG.
plot_name <- "MedicalHistoryRisk_1ybuffer"
ggsave(
  filename = glue("outputs/{plot_name}.png"), plot = mh_events, device = "png",
  width = plot_width, height = plot_height, dpi = plot_res, limitsize = FALSE
)

## Medical History State and Event Trajectories

In [None]:
# Medical-history groups per endpoint: High (percentiles 91-100), Mid (45-55)
# and Low (1-10); everyone else is dropped. MET is the same grouping as a
# factor with levels High, Mid, Low, used by the KM plots.
# NOTE(review): Mid spans 11 percentiles, High and Low 10 each.
logh_mh <- logh_inc %>%
  select(endpoint, model, eid, logh_perc, event, time) %>%
  group_by(endpoint) %>%
  mutate(MH = case_when(
    logh_perc %in% 91:100 ~ "High",
    logh_perc %in% 45:55 ~ "Mid",
    logh_perc %in% 1:10 ~ "Low",
    TRUE ~ "NA"
  )) %>%
  # MET is derived from MH; there is no MET column to build it from.
  mutate(MET = fct_rev(factor(MH, levels = c("Low", "Mid", "High")))) %>%
  ungroup() %>%
  filter(MH != "NA")

In [None]:
# library() fails loudly if the package is missing; require() only warned.
library(ggquickeda)
plot_width <- 50; plot_height <- 75; plot_res <- 320
options(repr.plot.width = plot_width, repr.plot.height = plot_height, repr.plot.res = plot_res)
met_map <- c("High" = "#023768", "Mid" = "#4F8EC1", "Low" = "#7AC6FF")

# Endpoints ordered by numeric phecode; all of them are kept.
endpoint_order <- (phecode_defs %>%
  mutate(phecode_rank = as.numeric(phecode)) %>%
  arrange(phecode_rank) %>%
  as_tibble())$endpoint
endpoint_selection <- endpoint_order
temp <- logh_mh %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
  filter(endpoint %in% endpoint_selection) %>%
  ungroup()

# Kaplan-Meier cumulative-event curves (with ticks and bands) per
# medical-history group, one facet per endpoint.
km_plot <- ggplot(temp, aes(time = time, status = event, fill = MET, color = MET, group = MET)) +
  geom_km(trans = "event") +
  geom_kmticks(trans = "event", size = 0.3) +
  geom_kmband(trans = "event") +
  labs(x = "Time [Years]", y = "Cumulative Events [%]") +
  scale_color_manual(values = met_map) +
  scale_fill_manual(values = met_map) +
  scale_y_continuous(labels = function(x) round(x*100, 1), expand = c(0, 0)) +
  scale_x_continuous(expand = expansion(add = c(0, .1)), breaks = c(5, 10)) +
  # `scales` is the facet_wrap() argument; `scale` relied on partial matching.
  facet_wrap(~endpoint, scales = "free", labeller = labeller(endpoint = endpoint_map), ncol = 25) +
  theme(legend.position = "none")

In [None]:
# Export the KM grid as PNG.
plot_name <- "MedicalHistory_KMs"
ggsave(
  filename = glue("outputs/{plot_name}.png"), plot = km_plot, device = "png",
  width = plot_width, height = plot_height, dpi = plot_res, limitsize = FALSE
)

# Top 1%

In [None]:
# Extreme medical-history groups per endpoint: High (percentile 100), Mid
# (50-51) and Low (percentile 1); everyone else is dropped. MET is the same
# grouping as a factor with levels High, Mid, Low.
logh_mh <- logh_inc %>%
  select(endpoint, model, eid, logh_perc, event, time) %>%
  group_by(endpoint) %>%
  mutate(
    MH = case_when(
      logh_perc == 100 ~ "High",
      logh_perc %in% 50:51 ~ "Mid",
      logh_perc == 1 ~ "Low",
      TRUE ~ "NA"
    ),
    MET = fct_rev(factor(MH, levels = c("Low", "Mid", "High")))
  ) %>%
  ungroup() %>%
  filter(MH != "NA")

In [None]:
# library() fails loudly if the package is missing; require() only warned.
library(ggquickeda)
plot_width <- 50; plot_height <- 75; plot_res <- 320
options(repr.plot.width = plot_width, repr.plot.height = plot_height, repr.plot.res = plot_res)
mh_map <- c("High" = "#023768", "Mid" = "#4F8EC1", "Low" = "#7AC6FF")

# Endpoints ordered by numeric phecode; all of them are kept.
endpoint_order <- (phecode_defs %>%
  mutate(phecode_rank = as.numeric(phecode)) %>%
  arrange(phecode_rank) %>%
  as_tibble())$endpoint
endpoint_selection <- endpoint_order
temp <- logh_mh %>%
  mutate(endpoint = factor(endpoint, levels = endpoint_order)) %>%
  filter(endpoint %in% endpoint_selection) %>%
  ungroup()

# Kaplan-Meier cumulative-event curves for the top/bottom 1% groups, one facet
# per endpoint.
km_plot <- ggplot(temp, aes(time = time, status = event, fill = MH, color = MH, group = MH)) +
  geom_km(trans = "event") +
  geom_kmticks(trans = "event", size = 0.3) +
  geom_kmband(trans = "event") +
  labs(x = "Time [Years]", y = "Cumulative Events [%]") +
  # Colour and fill both come from this section's mh_map; colour used to be
  # taken from met_map, which only exists if the 10% section ran first.
  scale_color_manual(values = mh_map) +
  scale_fill_manual(values = mh_map) +
  scale_y_continuous(labels = function(x) round(x*100, 1), expand = c(0, 0)) +
  scale_x_continuous(expand = expansion(add = c(0, .1)), breaks = c(5, 10)) +
  # `scales` is the facet_wrap() argument; `scale` relied on partial matching.
  facet_wrap(~endpoint, scales = "free", labeller = labeller(endpoint = endpoint_map), ncol = 25) +
  theme(legend.position = "none")

In [None]:
# Export the top/bottom 1% KM grid as PNG.
plot_name <- "MedicalHistory_KMs_Top1"
ggsave(
  filename = glue("outputs/{plot_name}.png"), plot = km_plot, device = "png",
  width = plot_width, height = plot_height, dpi = plot_res, limitsize = FALSE
)

# Figure 2

In [None]:
# Figure 2 (A/B): percentile versus observed event rate on top, Kaplan-Meier
# curves below.
plot_width <- 8.25; plot_height <- 10; plot_dpi <- 320
options(repr.plot.width = plot_width, repr.plot.height = plot_height, repr.plot.res = plot_dpi)
# The percentile plot object is `mh_events`; `met_events` is never defined.
# NOTE(review): mh_events and km_plot hold whichever versions were built last
# (the 1-year-buffer and top-1% ones when running top to bottom) -- confirm
# these are the intended panels.
fig2 <- mh_events / km_plot

In [None]:
fig2

In [None]:
library(gt)
# Export Figure 2 (A/B) as PDF.
plot_name <- "Figures_2_AB"
ggsave(
  filename = glue("outputs/{plot_name}.pdf"), plot = fig2, device = "pdf",
  width = plot_width, height = plot_height, dpi = 320
)