# Benchmarks

## Initialize

In [None]:
#library(Rmisc)
#library(dtplyr)
library(tidyverse)
library(glue)
library(arrow)
library(patchwork)
#library(data.table)
library("jsonlite")
library(ggthemes)

In [None]:
# Project root depends on the host: node names containing "sc" use the cluster
# share, every other machine uses the ag-reils shared drive.
base_path = if (grepl("sc", Sys.info()[["nodename"]], fixed = TRUE)) "/sc-projects/sc-proj-ukb-cvd" else "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

# Project-level folders (all derived from base_path).
project_label = "22_medical_records"
project_path = glue("{base_path}/results/projects/{project_label}")
figure_path = glue("{project_path}/figures")
output_path = glue("{project_path}/data")

# Folder of the experiment run analysed in this notebook.
experiment = "230425"
experiment_path = glue("{output_path}/{experiment}")

In [None]:
library(data.table)
library(arrow)

In [None]:
# Endpoint metadata of this experiment. phecode is read as text, not as a number.
endpoints_md = fread(glue("{experiment_path}/endpoints.csv"), colClasses = c(phecode = "character"))
endpoints = endpoints_md[["endpoint"]] %>% sort()

In [None]:
# Phecode definitions, ordered by endpoint.
endpoint_defs = glue("{output_path}/phecode_defs_220306.feather") %>% arrow::read_feather() %>% arrange(endpoint)

In [None]:
endpoint_defs %>% count(phecode_category)

In [None]:
endpoint_defs %>% filter(phecode_category %in% c("Death", "ID", "Pulmo", "Cardio")) %>% count(phecode_category)

In [None]:
# Baseline outcomes, returned as an Arrow Table rather than a tibble (as_data_frame = FALSE).
# The flag used to sit inside glue(), which treated it as an unused interpolation
# variable and ignored it, so read_feather() returned a data frame.
data_outcomes = arrow::read_feather(glue("{output_path}/baseline_outcomes_220627.feather"), as_data_frame = FALSE)

In [None]:
# Long table of coded records, restricted to the phecode vocabulary.
# Dots in concept_id are rewritten as dashes (presumably so that ids match endpoint
# names such as "phecode_092-2"). Each record's origin maps to a severity tier:
# gp_ctv3/gp_sct -> Light, hes_icd9/hes_icd10 -> Severe, death_records -> Fatal.
# Any other origin gets NA.
# The path is built from base_path so this cell follows the host switch at the top of
# the notebook. On the "sc" nodes it resolves to the same file as before.
records_long = arrow::read_feather(
        glue("{base_path}/data/2_datasets_pre/211110_anewbeginning/artifacts/final_records_omop_220531.feather"),
        col_select = all_of(c("eid", "birth_date", "exit_date", "concept_id", "date", "vocabulary", "origin"))) %>%
    filter(vocabulary == "phecode") %>%
    mutate(concept_id = str_replace_all(concept_id, "\\.", "\\-")) %>%
    mutate(severity = case_when(
        origin %in% c("gp_ctv3", "gp_sct") ~ "Light",
        origin %in% c("hes_icd9", "hes_icd10") ~ "Severe",
        origin == "death_records" ~ "Fatal")) %>%
    mutate(severity = factor(severity, levels = c("Light", "Severe", "Fatal")))

In [None]:
# One test-set prediction file per partition (0-21) of the Identity(Records)+MLP model.
# glue() is vectorised over `partitions`, so no loop or repeated c() is needed.
# as.character() keeps `paths` a plain character vector, as the old loop produced.
partitions = 0:21
partitions
paths = as.character(glue("{experiment_path}/loghs_covid/Identity(Records)+MLP/{partitions}/test.feather"))

In [None]:
# Stack the per-partition predictions. eid goes through character before integer so
# that factor-encoded ids convert by label rather than by level code.
predictions = map_df(paths, function(path) suppressWarnings(read_feather(path))) %>%
    mutate(eid = as.integer(as.character(eid)))
predictions_nocovid = arrange(predictions, eid)
head(predictions_nocovid)

In [None]:
# Covariate table (includes age and the other predictor columns used below).
data_covariates = glue("{output_path}/220627/data_covariates_full.feather") %>% arrow::read_feather()

In [None]:
#predictions = bind_rows(predictions_recruitment %>% mutate(t0="covid", endpoints="), predictions_2020 %>% mutate(t0="covid")) %>% left_join(data_covariates)

In [None]:
# Natural join: matches on every column name the two tables share.
predictions = left_join(predictions_nocovid, data_covariates)

In [None]:
filter(endpoints_md, str_detect(phecode_string, "Death|Pneumonia$|Sepsis$"))

In [None]:
# Endpoints whose category contains "ID".
filter(endpoints_md, str_detect(phecode_category, "ID"))

In [None]:
# Endpoints whose description mentions "Infections".
filter(endpoints_md, str_detect(phecode_string, "Infections"))

In [None]:
# Origin of follow-up time for the outcome analyses below.
t0_date = as.Date("2019-12-31", format = "%Y-%m-%d")

In [None]:
library(lubridate)
# Records dated after t0_date, for participants whose exit_date is after t0_date.
# Every row is an event (event = 1); time = years from t0_date to the record date.
data_outcomes_long = records_long %>%
    filter(date > t0_date, exit_date > t0_date) %>%
    mutate(event = 1,
           time = time_length(difftime(date, t0_date), "years")) %>%
    arrange(eid, concept_id, severity, time)

In [None]:
library(survcomp)

In [None]:
library(survival)
library(survminer)
# library() rather than require(): geom_km()/geom_kmticks()/geom_kmband() used by
# plot_km() come from ggquickeda, so a missing package must stop the notebook here
# instead of returning FALSE and failing later inside plot_km().
library(ggquickeda)

# Colours of the score bins drawn in the KM plots (bins are ntile(logh, 20);
# "20" is the highest-score bin).
mh_map = c("20"="#023768", "10"="#4F8EC1", "1"="#7AC6FF")

# Per-participant risk score and time-to-first-event table for one endpoint/severity.
#
# predictors : columns of `predictions` to combine. Each is converted to percentiles
#              with ntile(, 100), and the percentiles are summed into `logh`.
# endpoint   : concept_id counted as the event (e.g. "phecode_059").
# t0         : value of the `t0` column of `predictions` to keep (e.g. "covid").
#              Follow-up time is always measured from the global `t0_date`.
# severity   : severity tier of the records to count ("Light", "Severe" or "Fatal").
# Returns eid, the percentile-transformed predictors, logh, event, time (years since
# t0_date) and logh_bin (ntile(logh, 20) as a factor).
calc_km = function(predictors, endpoint, t0, severity){
    # Earliest qualifying record per participant; one row is kept on tied times.
    first_event = data_outcomes_long %>%
        filter(concept_id == !!endpoint, severity == !!severity) %>%
        select(eid, event, time) %>%
        group_by(eid) %>%
        slice_min(time, with_ties = FALSE) %>%
        ungroup()

    # Participants without an event get event = FALSE (cast into the numeric event
    # column) and are censored at the latest exit_date found in data_outcomes_long.
    # NOTE(review): this is a cohort-wide horizon, not each participant's own exit
    # date. Confirm this is intended.
    censor_at = time_length(difftime(max(data_outcomes_long$exit_date), t0_date), "years")

    predictions %>%
        filter(t0 == !!t0) %>%
        select(eid, all_of(predictors)) %>%
        mutate(across(all_of(predictors), ~ntile(.x, 100))) %>%
        mutate(logh = rowSums(across(all_of(predictors)))) %>%
        left_join(first_event, by = "eid") %>%
        replace_na(list(event = FALSE, time = censor_at)) %>%
        ungroup() %>%
        mutate(logh_bin = factor(ntile(logh, 20)))
    }

# C-index input table for one predictor set and severity: the calc_km() table, tagged
# with the predictor set and severity so that several calls can be stacked with bind_rows().
#
# predictors, endpoint, t0, severity : as for calc_km().
# Returns calc_km()'s columns plus `predictors` (names joined with ", ") and `severity`.
get_cindex_df = function(predictors, endpoint, t0, severity){
    # This pipeline used to be a verbatim copy of calc_km(). Delegating to it keeps the
    # C-index tables and the KM curves computed by exactly the same code.
    calc_km(predictors, endpoint, t0, severity) %>%
        mutate(predictors = paste(predictors, collapse = ", "), severity = severity)
    }

# Kaplan-Meier cumulative-incidence plot for score bins 1, 10 and 20 (of 20), for one
# endpoint and severity.
#
# predictors, endpoint, t0, severity : as for calc_km(), which builds the table.
# Side effects: prints `severity`, then the C-index (survcomp::concordance.index) of
# `logh` on a random subsample of at most 100,000 participants.
# Returns a ggplot object (KM layers come from ggquickeda).
plot_km = function(predictors, endpoint, t0, severity){
    km_df = calc_km(predictors, endpoint, t0, severity)

    # Subsample to keep concordance.index() fast. slice_sample() caps n at nrow(km_df);
    # the former sample_n(100000) raised an error on cohorts with fewer rows.
    cindex_df = km_df %>% slice_sample(n = 100000)
    print(severity)
    print(concordance.index(x = cindex_df$logh, surv.time = cindex_df$time, surv.event = cindex_df$event)$c.index)

    # Overall event proportion in percent, shown in the title and as a dashed reference line.
    ratio = round(mean(km_df$event) * 100, 2)

    g = ggplot(km_df %>% filter(logh_bin %in% c(1, 10, 20)),
               aes(time = time, status = event, fill = logh_bin, color = logh_bin, group = logh_bin)) +

        labs(title = glue("{severity} ({ratio}%)"), y = "Cumulative Incidence (%)", x = NULL) +
        geom_hline(yintercept = as.numeric(ratio) / 100, color = "black", alpha = 0.5, linetype = "21") +

        # Shaded first and second waves (x = years since t0_date).
        annotate("rect", xmin = 0.22, xmax = 0.45, ymin = 0, ymax = Inf, fill = "black", alpha = 0.02) +
        annotate("rect", xmin = 0.7, xmax = 1.25, ymin = 0, ymax = Inf, fill = "black", alpha = 0.02) +

        annotate("text", label = "First\nWave", x = 0.335, y = 0.07, hjust = 0.5) +
        annotate("text", label = "Second\nWave", x = 0.99, y = 0.07, hjust = 0.5) +

        geom_km(trans = "event") +
        geom_kmticks(trans = "event", size = 0.2, alpha = 0.01) +
        geom_kmband(trans = "event") +

        scale_color_manual(values = mh_map) + scale_fill_manual(values = mh_map) +

        scale_y_continuous(labels = function(x) round(x * 100, 1), expand = c(0, 0)) +
        scale_x_continuous(expand = expansion(add = c(0, .1)),
                           breaks = c(0, 0.25, 0.5, 0.75, 1, 1.25, 1.5),
                           labels = c("Jan 20", "April 20", "July 20", "Oct 20", "Jan 21", "April 21", "July 21")) +
        coord_cartesian(xlim = c(0, 1.6), ylim = c(0, 0.025)) +

        theme(legend.position = "none", plot.title = element_text(hjust = 0.5))

    return(g)
    }

In [None]:
# x-axis tick positions (years since t0_date) and their calendar labels.
breaks = seq(0, 1.5, by = 0.25)
labels = c("Jan 20", "April 20", "July 20", "Oct 20", "Jan 21", "April 21", "July 21")

# Follow-up landmarks in years since t0_date:
#   end of first wave   0.45
#   end of second wave  0.99
#   August 2021         1.6
# NOTE(review): plot_km shades the second wave from 0.7 to 1.25, so 0.99 may be the
# wave's midpoint rather than its end. Confirm.

In [None]:
# Font sizes shared by the figures in this notebook.
base_size = 8
title_size = 10
facet_size = 8.5
geom_text_size = 3

# Global ggplot theme: classic look, thin axis lines and ticks, legend at the bottom.
theme_set(
    theme_classic(base_size = base_size) +
        theme(
            strip.background = element_blank(),
            plot.title = element_text(size = title_size, hjust = 0),
            strip.text.x = element_text(size = facet_size),
            axis.title = element_text(size = 10),
            axis.text = element_text(size = 8, color = "black"),
            legend.position = "bottom",
            axis.line = element_line(size = 0.2),
            axis.ticks = element_line(size = 0.2)
        )
)

In [None]:
# Single-column figure size for the notebook's inline plots.
plot_width = 4.125
plot_height = 6
plot_res = 320
options(repr.plot.width = plot_width, repr.plot.height = plot_height, repr.plot.res = plot_res)

In [None]:
endpoint = "phecode_059"

# C-index inputs for the age-only score (severe, then fatal outcomes).
predictors = c("age")
cindex_df_age = bind_rows(lapply(c("Severe", "Fatal"), function(sev) get_cindex_df(predictors, endpoint, "covid", sev)))

# C-index inputs for age plus the three record-based predictors.
predictors = c("age", "OMOP_4306655", "phecode_468", "phecode_092-2")
cindex_df_agemh = bind_rows(lapply(c("Severe", "Fatal"), function(sev) get_cindex_df(predictors, endpoint, "covid", sev)))

cindex_df = bind_rows(cindex_df_age, cindex_df_agemh)

In [None]:
write_feather(cindex_df, "outputs/covid_cindeces_230425.feather")

In [None]:
# Age-only score: KM plots and their underlying tables (df_1, df_2) for severe and
# fatal COVID outcomes.
predictors = c("age")
endpoint = "phecode_059"
predictor_label = endpoints_md %>%
    filter(endpoint %in% predictors) %>%
    pull(phecode_string) %>%
    toString()

g1 = plot_km(predictors, endpoint, "covid", "Severe")
df_1 = calc_km(predictors, endpoint, "covid", "Severe") %>% mutate(score = "age only")
g2 = plot_km(predictors, endpoint, "covid", "Fatal")
df_2 = calc_km(predictors, endpoint, "covid", "Fatal") %>% mutate(score = "age only")

strat_age = (g1 / g2) +
    plot_annotation(title = glue('COVID outcomes stratified by {predictors}'))
strat_age

In [None]:
# Age + record-based predictors: KM plots and tables (df_3, df_4) for severe and fatal
# COVID outcomes.
predictors = c("age", "OMOP_4306655", "phecode_468", "phecode_092-2")
endpoint = "phecode_059"
predictor_label = endpoints_md %>%
    filter(endpoint %in% predictors) %>%
    pull(phecode_string) %>%
    toString()

g3 = plot_km(predictors, endpoint, "covid", "Severe")
df_3 = calc_km(predictors, endpoint, "covid", "Severe") %>% mutate(score = "age+mh")
g4 = plot_km(predictors, endpoint, "covid", "Fatal")
df_4 = calc_km(predictors, endpoint, "covid", "Fatal") %>% mutate(score = "age+mh")

strat_score = (g3 / g4) +
    plot_annotation(title = glue('COVID outcomes stratified by aggregated\npartial hazards of {predictor_label}'))
strat_score

In [None]:
# Two-column figure size for the combined panel.
plot_width = 8.25
plot_height = 6
plot_res = 320
options(repr.plot.width = plot_width, repr.plot.height = plot_height, repr.plot.res = plot_res)

# Age-only panel on the left, combined-score panel on the right.
fig_covid = (strat_age) | (strat_score)
fig_covid

In [None]:
plot_name = "SupplFigure6d_covid"
ggsave(filename = glue("outputs/{plot_name}.pdf"), plot = fig_covid, device = "pdf",
       width = plot_width, height = plot_height, dpi = plot_res)

In [None]:
# Age of each participant at t0_date:
#   floor(age at recruitment + years between the assessment-centre visit and t0_date).
# The path is built from base_path (same file as before on the "sc" nodes), and the
# reference date reuses t0_date instead of repeating the "2019-12-31" literal.
age_new = arrow::read_feather(
        glue("{base_path}/data/2_datasets_pre/220603_medicalhistory/baseline_covariates.feather"),
        col_select = c("eid", "age_at_recruitment_f21022_0_0", "date_of_attending_assessment_centre_f53_0_0")) %>%
    mutate(date_covid = t0_date) %>%
    mutate(diff_covid = time_length(difftime(date_covid, date_of_attending_assessment_centre_f53_0_0), "years")) %>%
    mutate(age_covid = floor(age_at_recruitment_f21022_0_0 + diff_covid)) %>%
    select(eid, age_covid)

In [None]:
# Median (with interval) of age at t0_date among participants in the top ventile of `age`.
predictions %>%
    left_join(age_new) %>%
    select(eid, age, age_covid) %>%
    mutate(age_bin = ntile(age, 20)) %>%
    filter(age_bin == 20) %>%
    arrange(age_covid) %>%
    ggdist::median_qi(age_covid)

In [None]:
# Kaplan-Meier cumulative incidence of one score bin, read at the last KM step strictly
# before each cutoff.
#
# df      : a calc_km() table (needs columns time, event and logh_bin).
# cutoffs : follow-up times in years since t0_date. The defaults are the end of the
#           first wave (0.45), the second-wave landmark (0.99) and August 2021 (1.6).
# bin     : logh_bin level to summarise (default 20, the top bin).
# Returns one row per cutoff that has at least one earlier KM step. Columns are those of
# surv_summary(), plus cumevents, label ("x% (CI lo%, hi%), n events") and times (the cutoff).
calc_km_df = function(df, cutoffs = c(0.45, 0.99, 1.6), bin = 20){
    top_bin = df %>% filter(logh_bin == !!bin)
    fit <- survfit(Surv(time, event) ~ logh_bin, data = top_bin)

    # Pass `data` explicitly. Otherwise surv_summary() re-evaluates fit$call$data
    # inside survminer, where the local data frame is not visible.
    km_summary = surv_summary(fit, data = top_bin) %>%
        mutate(cumevents = cumsum(n.event)) %>%
        mutate(label = glue("{round((1-surv)*100, 2)}% (CI {round((1-upper)*100, 2)}%, {round((1-lower)*100, 2)}%), {cumevents} events"))

    # One filter per cutoff replaces the former constant-key left_join, a disguised
    # cross join that dplyr >= 1.1 reports as an unexpected many-to-many relationship.
    bind_rows(lapply(cutoffs, function(cutoff) {
        km_summary %>%
            filter(time < !!cutoff) %>%
            slice_max(time) %>%
            mutate(times = !!cutoff)
    }))
    }

In [None]:
# Cumulative-incidence labels of the top score bin at each cutoff, for both scores and
# both severities. The result is one row per score/severity and one column per cutoff.
km_df_all = list(
        calc_km_df(df_1) %>% mutate(score = "age only", severity = "Severe"),
        calc_km_df(df_2) %>% mutate(score = "age only", severity = "Fatal"),
        calc_km_df(df_3) %>% mutate(score = "age+hr", severity = "Severe"),
        calc_km_df(df_4) %>% mutate(score = "age+hr", severity = "Fatal")
    ) %>%
    bind_rows() %>%
    select(times, score, severity, label) %>%
    pivot_wider(names_from = "times", values_from = "label") %>%
    arrange(desc(severity))

In [None]:
# Give the cutoff columns readable names. pivot_wider() named them after the cutoff
# values ("0.45", "0.99", "1.6"). Renaming by those names instead of by position means a
# missing or reordered cutoff column raises an error rather than silently mislabelling.
km_df_all = km_df_all %>%
    rename(all_of(c("after first wave" = "0.45",
                    "after second wave" = "0.99",
                    "until August 21" = "1.6")))

In [None]:
arrange(km_df_all, score, desc(severity))