In [None]:
###############################################################
# Import dependencies
rm(list = ls())
library(reticulate)
sagemaker <- import('sagemaker')
pd <- import("pandas")

# Packages that are called via `pkg::fun()` further down and therefore must be
# installed. pROC is included because pROC::roc() is used in the
# stratum-by-stratum analysis but was previously never checked for.
extra_packages_reqd <- c("aws.s3", "aws.ec2metadata", "pROC")
packages_available <- library()$results[, "Package"]
# Install only what is missing; each package is installed in its own call so
# that one failure does not prevent the others from being attempted.
for (package_current in setdiff(extra_packages_reqd, packages_available)) {
    install.packages(package_current)
}

In [None]:
###############################################################
# Load meta data file with test/train/... splits
#
# Objects created here and used by later cells:
#   dir_root          - prefix of the S3 object paths (read from config.RDS)
#   plot_dir          - local output directory for figures
#   meta              - participant/recording metadata with split indicators
#   meta$stratum      - stratum label: matched variables joined with "_"
#   matched_vars_tidy - display names for the matched variables
setwd('/home/ec2-user/SageMaker/jbc-cough-in-a-box')
dir_root <- readRDS(file.path("notebooks/matching/config.RDS"))
plot_dir <- file.path("plotting", "stratified_analysis")
# plot_dir is nested, so create parent directories as well; without
# recursive = TRUE the call fails (silently, as warnings are suppressed) when
# "plotting" does not exist yet, and the later pdf() call would then error.
dir.create(plot_dir, showWarnings = FALSE, recursive = TRUE)
Sys.setenv("AWS_DEFAULT_REGION" = "eu-west-2")
meta <- aws.s3::s3read_using(
    read.csv,
    object = file.path(dir_root, "BAMstudy2022-prep/meta_data_with_splits_old_format.csv")
)
# Variables that define a stratum in the matched test set ...
matched_vars_test_set <- c("react_or_tt",
                           "gender",
                           "age_binned",
                           "cough",
                           "no_symptoms",
                           "Sore.throat",
                           "Asthma",
                           "Shortness.of.breath",
                           "Runny.or.blocked.nose")
# ... and in the matched training set.
matched_vars_train_set <- c("gender",
                            "age_binned",
                            "cough",
                            "Sore.throat",
                            "Asthma",
                            "Shortness.of.breath",
                            "Runny.or.blocked.nose",
                            "COPD.or.Emphysema",
                            "smoker_status")
# The rest of the notebook stratifies by the test-set matching variables.
matched_vars_curr <- matched_vars_test_set
# Create stratum labels by concatenating relevant metadata.
# NOTE(review): apply() converts the selected columns to a character matrix
# first; the labels are later split on "_", so this presumably relies on no
# metadata value containing an underscore - TODO confirm.
meta$stratum <- apply(meta[, matched_vars_curr, drop = FALSE], 1,
                      function(x) paste(x, collapse = "_"))
# Human-readable axis labels, paired positionally with matched_vars_curr.
matched_vars_tidy <- c("Channel", "Gender", "Age", "Cough", "Symptoms",
                       "Sore throat", "Asthma", "Shortness of breath",
                       "Runny or blocked nose")
names(matched_vars_tidy) <- matched_vars_curr

In [None]:
###############################################################
# Analyse classifier performance stratum-by-stratum
#
# For every selected test split / training regime combination this cell
#   1. reads the classifier predictions from S3 and attaches to each prediction
#      the stratum label of its recording (matched on `audio_sentence`);
#   2. for each stratum with enough samples, runs a one-sided Wilcoxon test of
#      the "Positive" score in COVID+ versus COVID- recordings;
#   3. estimates the ROC-AUC and its confidence interval in that stratum;
#   4. adjusts the p-values across strata with Benjamini-Hochberg.
#
# Objects produced for later cells:
#   res_all[[name]][[stratum]] - case / control prediction vectors
#   roc_list[[name]]           - data.frame: stratum, roc_l, roc_m, roc_u (numeric)
#   strat_pvals, strat_qvals   - one row per stratum, one column per `name`
#   nam_curr                   - name of the last combination analysed
# where `name` is paste0(train_type, "_", test_type).

all_strata <- unique(meta$stratum[meta$in_matched_rebalanced_long_test | meta$in_matched_rebalanced_test])
all_matched_test_strata <- unique(meta$stratum[meta$in_matched_rebalanced_test])
strat_pvals <- strat_qvals <- data.frame(stratum = all_strata)
# The trailing index selects which of the available configurations are run.
test_types <- c("matched_test", "matched_long_test", "matched_base_and_long_test")[3]
train_types <- c("standard_train", "matched_train")[1]
res_all <- list()
# Initialised once, outside the loops, so that results for every test type are
# kept (resetting it per test type would discard earlier entries).
roc_list <- list()
for (test_type in test_types) {
    # Row selector into `meta` and prediction files for this test split.
    if (test_type == "matched_test") {
        in_test_set <- meta$in_matched_rebalanced_test
        standard_train_input_files <- "ss_predicts_standard_train_matched_test.csv"
        matched_train_input_files <- "ss_predicts_matched_train_matched_test.csv"
    } else if (test_type == "matched_long_test") {
        in_test_set <- meta$in_matched_rebalanced_long_test
        standard_train_input_files <- "ss_predicts_standard_train_long_matched_test.csv"
        matched_train_input_files <- "ss_predicts_matched_train_long_matched_test.csv"
    } else if (test_type == "matched_base_and_long_test") {
        in_test_set <- meta$in_matched_rebalanced_test | meta$in_matched_rebalanced_long_test
        standard_train_input_files <- c("ss_predicts_standard_train_test.csv",
                                        "ss_predicts_standard_train_long_test.csv")
        matched_train_input_files <- "ss_predicts_matched_train_test.csv"
    }
    meta_matched_test <- meta[in_test_set, ]
    # which() drops rows whose split indicator is NA.
    meta_curr <- meta[which(in_test_set), ]

    # We selected a stratum sample size threshold of n = 20 to ensure
    # some precision when estimating accuracy
    n_thresh_in_matched_test <- 20
    n_in_strata <- table(meta_matched_test$stratum)
    print(paste0("n strata with sample size >= ",
                 n_thresh_in_matched_test, " is ", sum(n_in_strata >= n_thresh_in_matched_test)))
    strata_to_analyse <- names(n_in_strata)[n_in_strata >= n_thresh_in_matched_test]

    print(test_type)
    input_files <- list(standard_train = standard_train_input_files,
                        matched_train = matched_train_input_files)
    for (train_type in train_types) {
        # Key under which all results for this combination are stored.
        nam_curr <- paste0(train_type, "_", test_type)
        input_files_curr <- input_files[[train_type]]
        # Read every prediction file for this combination and stack them.
        res <- do.call(rbind, lapply(input_files_curr, function(input_file) {
            aws.s3::s3read_using(read.csv,
                                 object = file.path(dir_root, "audio_sentences_for_matching",
                                                    input_file))
        }))
        print(dim(res))
        # Attach stratum labels; predictions without a match in this test split are dropped.
        res$stratum <- meta_curr[match(res$audio_sentence, meta_curr$audio_sentence), "stratum"]
        res <- res[!is.na(res$stratum), ]
        print(dim(meta_curr))
        print(dim(res))
        p_vals_pos <- p_vals_neg <- c()
        res_by_stratum <- list()
        # Numeric store for the AUC confidence intervals (lower, point estimate, upper).
        roc_ci <- matrix(NA_real_, nrow = length(strata_to_analyse), ncol = 3,
                         dimnames = list(NULL, c("roc_l", "roc_m", "roc_u")))
        # For each stratum we perform a Wilcoxon test on predictions between COVID+/COVID- subgroups
        # and estimate a ROC curve for predicting COVID in that stratum
        for (i in seq_along(strata_to_analyse)) {
            stratum <- strata_to_analyse[i]
            res_curr <- res[which(res$stratum == stratum), ]
            case_vals_test <- res_curr[res_curr$test_result == 1, "Positive"]
            control_vals_test <- res_curr[res_curr$test_result == 0, "Positive"]
            res_by_stratum[[stratum]] <- list(covid_case_pos_prob = case_vals_test,
                                              covid_control_pos_prob = control_vals_test)
            # One-sided test: are scores higher in COVID+ than in COVID-?
            wilcox_out_pos <- wilcox.test(case_vals_test,
                                          control_vals_test,
                                          alternative = "greater")
            p_vals_pos[stratum] <- wilcox_out_pos$p.value
            # Labels are built from the actual group sizes so that they stay
            # aligned with `predictions` even if a stratum is not perfectly balanced.
            covid_status <- rep(c(0, 1),
                                times = c(length(control_vals_test), length(case_vals_test)))
            predictions <- c(control_vals_test,
                             case_vals_test)
            # levels/direction are fixed so that AUC > 0.5 always means "cases score
            # higher than controls", consistent with the one-sided test above.
            # Letting pROC choose the direction per stratum biases AUCs upwards.
            # NOTE(review): `smoothed` and `ci.alpha` are kept from the original call;
            # pROC documents `smooth` and `conf.level` instead - TODO confirm whether
            # these two arguments have any effect.
            roc_out <- pROC::roc(covid_status,
                                 predictions,
                                 levels = c(0, 1), direction = "<",
                                 smoothed = TRUE,
                                 # arguments for ci
                                 ci = TRUE, ci.alpha = 0.95, stratified = FALSE,
                                 # arguments for plot
                                 plot = FALSE, auc.polygon = TRUE, max.auc.polygon = TRUE, grid = TRUE,
                                 print.auc = TRUE, show.thres = TRUE, quiet = TRUE)
            roc_ci[i, ] <- as.numeric(roc_out$ci)
        }
        res_all[[nam_curr]] <- res_by_stratum
        # Keep the CI columns numeric (binding them with the stratum label in a
        # single vector would turn them into character strings).
        roc_all <- data.frame(stratum = strata_to_analyse, roc_ci, stringsAsFactors = FALSE)
        roc_all[, matched_vars_test_set] <- meta[match(roc_all$stratum, meta$stratum), matched_vars_test_set]
        roc_list[[nam_curr]] <- roc_all
        # Adjust p-values for multiple testing
        q_vals_pos <- p.adjust(p_vals_pos, method = "BH")
        sig_strata <- names(q_vals_pos)[q_vals_pos < .05]
        strat_pvals[match(names(p_vals_pos), strat_pvals$stratum), nam_curr] <- p_vals_pos
        strat_qvals[match(names(q_vals_pos), strat_qvals$stratum), nam_curr] <- q_vals_pos
    }
}
# Append the matched variables so each stratum row can be interpreted.
strat_qvals[, matched_vars_test_set] <- meta[match(strat_qvals$stratum, meta$stratum), matched_vars_test_set]
strat_pvals[, matched_vars_test_set] <- meta[match(strat_pvals$stratum, meta$stratum), matched_vars_test_set]
# Cross-tabulate significance between the base and long test sets. These two
# columns only exist when both of those test types are selected in `test_types`;
# otherwise the table is empty.
table(strat_qvals$standard_train_matched_test < .05, strat_qvals$standard_train_matched_long_test < .05)
names(strat_qvals)

In [None]:
###############################################################
# Plot ROC-AUC confidence intervals in each stratum
#
# One vertical CI bar per stratum (sorted by AUC point estimate), with the
# stratum's matched-variable values and its sample size written below the axis.
# Points are filled when the stratum's BH-adjusted Wilcoxon p-value is < 0.05.

# Train/test combination being plotted. Every lookup below uses this name so
# the figure does not depend on whichever combination the analysis loop
# happened to process last.
nam_plot <- "standard_train_matched_base_and_long_test"
roc_all <- roc_list[[nam_plot]]
# The CI columns may have been stored as character strings; make sure sorting
# and plotting operate on numbers.
roc_cols <- c("roc_l", "roc_m", "roc_u")
roc_all[roc_cols] <- lapply(roc_all[roc_cols], function(x) as.numeric(as.character(x)))
roc_all <- roc_all[order(roc_all$roc_m), ]
strata_plot <- roc_all$stratum
n_strata <- length(strata_plot)
pdf(file.path(plot_dir, "roc-auc_cis.pdf"), 12, 8)
# Large bottom margin leaves room for the per-stratum annotation rows.
par(mar = c(24, 8, 2, 3))
auc_ref <- 0.62      # reference AUC drawn as a dashed horizontal line
tick_eps <- .1       # half-width of the CI end caps (x units)
ylim_lower <- .2     # lower y-axis limit; annotations start below this
shift_down <- .1     # vertical spacing between annotation rows (y units)
plot(roc_all$roc_m, type = "n",
     ylim = c(ylim_lower, 1),
     xaxt = "n",
     xlab = "",
     ylab = "ROC-AUC",
     las = 2,
     yaxs = "i",
     bty = "n")
for (stratum in strata_plot) {
    # x position of this stratum; also its row in roc_all, since strata_plot
    # is roc_all$stratum.
    xpl <- match(stratum, strata_plot)
    # Faint guide line linking the CI to its annotations (drawn outside the plot region).
    par(xpd = NA)
    lines(x = rep(xpl, 2), y = c(1, -.9), col = gray(0.95), lty = 3)
    par(xpd = F)
    # Confidence interval bar with end caps.
    ypl <- unlist(roc_all[xpl, c("roc_l", "roc_u")])
    col_ci <- 1#ifelse(ypl[1] > auc_ref, 2, 1)
    lines(x = rep(xpl, 2), y = ypl, col = col_ci)
    lines(x = xpl + c(-1, 1) * tick_eps, y = rep(ypl[1], 2), col = col_ci)
    lines(x = xpl + c(-1, 1) * tick_eps, y = rep(ypl[2], 2), col = col_ci)
    # Recover the individual matched-variable values from the stratum label and
    # shorten them. Order matters: the longer pattern must be replaced before
    # the shorter one it contains ("Female" before "Male", etc.).
    stratum_responses <- strsplit(stratum, split = "_")[[1]]
    stratum_responses <- gsub("No symptoms", "No", stratum_responses)
    stratum_responses <- gsub("Symptoms", "Yes", stratum_responses)
    stratum_responses <- gsub("Female", "F", stratum_responses)
    stratum_responses <- gsub("Male", "M", stratum_responses)
    stratum_responses <- gsub("No cough", "No", stratum_responses)
    stratum_responses <- gsub("Cough", "Yes", stratum_responses)
    names(stratum_responses) <- matched_vars_test_set
    # One annotation row per matched variable.
    line_curr <- 0
    for (varc in matched_vars_test_set) {
        line_curr <- line_curr + 1
        par(xpd = NA)
        text(x = xpl, y = ylim_lower - line_curr * shift_down,
             cex = .6,
             labels = stratum_responses[varc],
             las = 2, srt = 45,
             adj = 1)
        par(xpd = F)
    }
    # Final annotation row: number of predictions (cases + controls) analysed
    # in this stratum.
    preds_curr <- res_all[[nam_plot]][[stratum]]
    n_in_stratum <- length(preds_curr$covid_case_pos_prob) +
        length(preds_curr$covid_control_pos_prob)
    par(xpd = NA)
    line_curr <- line_curr + 1
    text(x = xpl, y = ylim_lower - line_curr * shift_down,
         cex = .6,
         labels = n_in_stratum,
         las = 2, srt = 90,
         adj = 1)
    par(xpd = F)
    # AUC point estimate; filled when the adjusted p-value is below 0.05.
    qval_curr <- strat_qvals[[nam_plot]][match(stratum, strat_qvals$stratum)]
    points(x = xpl, y = roc_all$roc_m[xpl],
           pch = 21,
           bg = ifelse(qval_curr < .05, "black", "white"),
           col = 1)
}
# Row headings for the annotation rows, placed to the left of the first stratum.
line_curr <- 0
x_lab_at <- -5
for (varc in matched_vars_test_set) {
    line_curr <- line_curr + 1
    par(xpd = NA)
    text(x = x_lab_at, y = ylim_lower - line_curr * shift_down,
         cex = .8,
         labels = matched_vars_tidy[varc],
         las = 2, srt = 45,
         adj = 1)
    par(xpd = F)
}
par(xpd = NA)
line_curr <- line_curr + 1
text(x = x_lab_at, y = ylim_lower - line_curr * shift_down,
     cex = .8,
     labels = "# in stratum",
     las = 2, srt = 45,
     adj = 1)
par(xpd = F)
# Reference lines: dashed at auc_ref, solid at chance level.
abline(h = auc_ref, col = 1, lty = 2)
abline(h = 0.5, col = 1)
# Only the filled-point entry is shown; the indices pick it out of the
# candidate legend entries.
legend(x = "topleft", pch = 21, pt.bg = c("white", "black", "white")[2],
       col = c("black", "black", "red")[2],
       legend = c("ROC-AUC > 0.5 (FDR < 0.05)",
                  "ROC-AUC > 0.62 (p < 0.05)")[1],
       cex = .8, bg = "white")
mtext(side = 4, at = c(.5, .62), text = c("0.5", "0.62"), las = 2)
dev.off()

In [None]:
###############################################################
# Gather some numbers for the text of the paper

# Number of strata plotted, and number of strata overall.
length(strata_plot)
length(all_strata)
# Count the strata whose AUC confidence interval contains the reference AUC.
# The CI bounds may be stored as character strings (they are bound together
# with the stratum label when the table is built), so coerce them to numbers
# first; otherwise `<` / `>` would compare strings rather than values.
ci_rows <- match(strata_plot, roc_all$stratum)
ci_lower <- as.numeric(as.character(roc_all[ci_rows, "roc_l"]))
ci_upper <- as.numeric(as.character(roc_all[ci_rows, "roc_u"]))
n_overlap_ref_auc <- sum(ci_lower < auc_ref & ci_upper > auc_ref)
# ... and the same as a proportion of plotted strata.
n_overlap_ref_auc / length(strata_plot)
# Number of plotted strata with a BH-adjusted p-value below 0.05.
qvals_plot <- strat_qvals$standard_train_matched_base_and_long_test[
        match(strata_plot, strat_qvals$stratum)]
n_qval_sig <- sum(qvals_plot < .05)
n_qval_sig
# Proportion of strata with an adjusted p-value below 0.05, among those that
# have one (strata that were not analysed are NA and are ignored).
mean(strat_qvals$standard_train_matched_base_and_long_test < .05, na.rm = TRUE)