From 853b29bd4debf2a28fc3e27625f2aea8ad867018 Mon Sep 17 00:00:00 2001 From: Glen Martin Date: Mon, 9 Oct 2023 12:55:33 +0100 Subject: [PATCH] change calibration plot format --- DESCRIPTION | 4 +- NAMESPACE | 3 + R/flex_calplot.R | 128 +++++++------ R/pred_validate.R | 177 +++++++++++++++--- R/validate_logistic.R | 68 ++++--- R/validate_survival.R | 67 ++++--- man/pred_validate.Rd | 50 ++--- .../testthat/_snaps/pred_validate_logistic.md | 2 +- .../testthat/_snaps/pred_validate_survival.md | 2 +- tests/testthat/test-pred_validate_logistic.R | 10 +- tests/testthat/test-pred_validate_survival.R | 4 +- vignettes/predRupdate.Rmd | 43 ++--- 12 files changed, 354 insertions(+), 204 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 51c0468..fca8bb3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -33,8 +33,8 @@ Imports: survival, pROC, ggplot2, - ggExtra, - rlang + rlang, + ggpubr Depends: R (>= 2.10) VignetteBuilder: knitr diff --git a/NAMESPACE b/NAMESPACE index 6677f9e..de7dbc7 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -3,6 +3,7 @@ S3method(map_newdata,default) S3method(map_newdata,predinfo_logistic) S3method(map_newdata,predinfo_survival) +S3method(plot,predvalidate) S3method(pred_predict,default) S3method(pred_predict,predinfo_logistic) S3method(pred_predict,predinfo_survival) @@ -16,6 +17,8 @@ S3method(pred_validate,default) S3method(pred_validate,predinfo_logistic) S3method(pred_validate,predinfo_survival) S3method(print,predinfo) +S3method(print,predvalidate_logistic) +S3method(print,predvalidate_survival) S3method(summary,predSR) S3method(summary,predUpdate) S3method(summary,predinfo) diff --git a/R/flex_calplot.R b/R/flex_calplot.R index e1a3c5c..82b233d 100644 --- a/R/flex_calplot.R +++ b/R/flex_calplot.R @@ -7,6 +7,7 @@ flex_calplot <- function(model_type = c("logistic", "survival"), ylim, xlab, ylab, + pred_rug, time_horizon = NULL) { model_type <- as.character(match.arg(model_type)) @@ -25,39 +26,39 @@ flex_calplot <- function(model_type = c("logistic", "survival"), spline_model <- stats::glm(ObservedOutcome ~ splines::ns(LP, df = 3), family = stats::binomial(link = "logit")) spline_preds <- stats::predict(spline_model, type = "response", se = T) - plot_df <- data.frame("p" = Prob, + plot_df <- data.frame("ObservedOutcome" = ObservedOutcome, + "p" = Prob, "o" = spline_preds$fit) - print(ggExtra::ggMarginal(ggplot2::ggplot(plot_df, - ggplot2::aes(x = .data$p, - y = .data$o)) + - ggplot2::geom_line(ggplot2::aes(linetype = "Calibration Curve", - colour = "Calibration Curve")) + - ggplot2::xlim(xlim) + - ggplot2::ylim(ylim) + - ggplot2::xlab(xlab) + - ggplot2::ylab(ylab) + - ggplot2::geom_abline(ggplot2::aes(intercept = 0, slope = 1, - linetype = "Reference", - colour = "Reference"), - show.legend = FALSE) + - ggplot2::geom_point(alpha = 0) + - ggplot2::coord_fixed() + - ggplot2::theme_bw(base_size = 12) + - ggplot2::labs(color = "Guide name", linetype = "Guide name") + - ggplot2::scale_linetype_manual(values = c("dashed", - "solid"), - breaks = c("Reference", - "Calibration Curve"), - labels = c("Reference", - "Calibration Curve")) + - ggplot2::scale_colour_manual(values = c("black", - "blue"), - breaks = c("Reference", - "Calibration Curve")) + - ggplot2::theme(legend.title=ggplot2::element_blank()), - type = "histogram", - margins = "x")) + calplot <- ggplot2::ggplot(plot_df, + ggplot2::aes(x = .data$p, + y = .data$o)) + + ggplot2::geom_line(ggplot2::aes(linetype = "Calibration Curve", + colour = "Calibration Curve")) + + ggplot2::xlim(xlim) + + ggplot2::ylim(ylim) + + 
ggplot2::xlab(xlab) + + ggplot2::ylab(ylab) + + ggplot2::geom_abline(ggplot2::aes(intercept = 0, slope = 1, + linetype = "Reference", + colour = "Reference"), + show.legend = FALSE) + + ggplot2::geom_point(alpha = 0) + + ggplot2::coord_fixed() + + ggplot2::theme_bw(base_size = 12) + + ggplot2::labs(color = "Guide name", linetype = "Guide name") + + ggplot2::scale_linetype_manual(values = c("dashed", + "solid"), + breaks = c("Reference", + "Calibration Curve"), + labels = c("Reference", + "Calibration Curve")) + + ggplot2::scale_colour_manual(values = c("black", + "blue"), + breaks = c("Reference", + "Calibration Curve")) + + ggplot2::theme(legend.title=ggplot2::element_blank(), + legend.position = "top") } else { cloglog <- log(-log(1 - Prob)) @@ -70,36 +71,41 @@ flex_calplot <- function(model_type = c("logistic", "survival"), bh <- survival::basehaz(vcal) plot_df$observed_risk <- 1 - (exp(-bh[(max(which(bh[,2] <= time_horizon))),1])^(exp(stats::predict(vcal, type = "lp")))) - print(ggExtra::ggMarginal(ggplot2::ggplot(plot_df, - ggplot2::aes(x = .data$Prob, - y = .data$observed_risk)) + - ggplot2::geom_line(ggplot2::aes(linetype = "Calibration Curve", - colour = "Calibration Curve")) + - ggplot2::xlim(xlim) + - ggplot2::ylim(ylim) + - ggplot2::xlab(xlab) + - ggplot2::ylab(ylab) + - ggplot2::geom_abline(ggplot2::aes(intercept = 0, slope = 1, - linetype = "Reference", - colour = "Reference"), - show.legend = FALSE) + - ggplot2::geom_point(alpha = 0) + - ggplot2::coord_fixed() + - ggplot2::theme_bw(base_size = 12) + - ggplot2::labs(color = "Guide name", linetype = "Guide name") + - ggplot2::scale_linetype_manual(values = c("dashed", - "solid"), - breaks = c("Reference", - "Calibration Curve"), - labels = c("Reference", - "Calibration Curve")) + - ggplot2::scale_colour_manual(values = c("black", - "blue"), - breaks = c("Reference", - "Calibration Curve")) + - ggplot2::theme(legend.title=ggplot2::element_blank()), - type = "histogram", - margins = "x")) + calplot <- ggplot2::ggplot(plot_df, + ggplot2::aes(x = .data$Prob, + y = .data$observed_risk)) + + ggplot2::geom_line(ggplot2::aes(linetype = "Calibration Curve", + colour = "Calibration Curve")) + + ggplot2::xlim(xlim) + + ggplot2::ylim(ylim) + + ggplot2::xlab(xlab) + + ggplot2::ylab(ylab) + + ggplot2::geom_abline(ggplot2::aes(intercept = 0, slope = 1, + linetype = "Reference", + colour = "Reference"), + show.legend = FALSE) + + ggplot2::geom_point(alpha = 0) + + ggplot2::coord_fixed() + + ggplot2::theme_bw(base_size = 12) + + ggplot2::labs(color = "Guide name", linetype = "Guide name") + + ggplot2::scale_linetype_manual(values = c("dashed", + "solid"), + breaks = c("Reference", + "Calibration Curve"), + labels = c("Reference", + "Calibration Curve")) + + ggplot2::scale_colour_manual(values = c("black", + "blue"), + breaks = c("Reference", + "Calibration Curve")) + + ggplot2::theme(legend.title=ggplot2::element_blank(), + legend.position = "top") } + if(pred_rug == TRUE){ + calplot <- calplot + + ggplot2::geom_rug(sides="b", alpha = 0.2) + } + + return(calplot) } diff --git a/R/pred_validate.R b/R/pred_validate.R index 203bc67..a97b4b9 100644 --- a/R/pred_validate.R +++ b/R/pred_validate.R @@ -58,14 +58,15 @@ #' the observed risk and the predicted risks, across the full risk range) and #' discrimination (ability of the model to distinguish between those who #' develop the outcome and those who do not) are calculated. For calibration, -#' calibration-in-the-large (CITL) and calibration slopes are estimated. 
CITL
-#' is estimated by fitting a logistic regression model to the observed binary
-#' outcomes, with the linear predictor of the model as an offset. For
-#' calibration slope, a logistic regression model is fit to the observed
-#' binary outcome with the linear predictor from the model as the only
-#' covariate. For discrimination, the function estimates the area under the
-#' receiver operating characteristic curve (AUC). Various other metrics are
-#' also calculated to assess overall accuracy (Brier score, Cox-Snell R2).
+#' the observed-to-expected ratio, calibration intercept and calibration
+#' slope are estimated. The calibration intercept is estimated by fitting a
+#' logistic regression model to the observed binary outcomes, with the linear
+#' predictor of the model as an offset. For calibration slope, a logistic
+#' regression model is fit to the observed binary outcome with the linear
+#' predictor from the model as the only covariate. For discrimination, the
+#' function estimates the area under the receiver operating characteristic
+#' curve (AUC). Various other metrics are also calculated to assess overall
+#' accuracy (Brier score, Cox-Snell R2).
 #'
 #' In the case of validating a survival prediction model, this function
 #' assesses the predictive performance of the linear predictor and
@@ -84,23 +85,29 @@
 #' (TRUE), or not (FALSE). The calibration plot is produced by regressing the
 #' observed outcomes against a cubic spline of the logit of predicted risks
 #' (for a logistic model) or the complementary log-log of the predicted risks
-#' (for a survival model). A histogram of the predicted risk distribution is
-#' displayed on the top x-axis. Users can specify parameters to modify the
+#' (for a survival model). Users can specify parameters to modify the
 #' calibration plot. Specifically, one can specify: \code{xlab}, \code{ylab},
 #' \code{xlim}, and \code{ylim} to change plotting characteristics for the
-#' calibration plot.
+#' calibration plot. A rug can be added to the x-axis of the plot by setting
+#' \code{pred_rug} to TRUE; this can be used to show the distribution of the
+#' predicted risks.
 #'
 #' @return \code{\link{pred_validate}} returns an object of class
 #' "\code{predvalidate}", with child classes per \code{model_type}. This is a
 #' list of performance metrics, estimated by applying the existing prediction
 #' model to the new_data. An object of class "\code{predvalidate}" is a list
 #' containing relevant calibration and discrimination measures. For logistic
-#' regression models, this will include calibration-intercept, calibration
-#' slope, area under the ROC curve, R-squared, and Brier Score. For survival
-#' models, this will include observed:expected ratio (if \code{cum_hazard} is
-#' provided to \code{x}), calibration slope, and Harrell's C-statistic.
-#' Optionally, a flexible calibration plot is also produced, along with a
-#' histogram of the predicted risk distribution.
+#' regression models, this will include observed:expected ratio,
+#' calibration-intercept, calibration slope, area under the ROC curve,
+#' R-squared, and Brier Score. For survival models, this will include
+#' observed:expected ratio (if \code{cum_hazard} is provided to \code{x}),
+#' calibration slope, and Harrell's C-statistic. Optionally, a flexible
+#' calibration plot is also produced, along with a box-plot and violin plot of
+#' the predicted risk distribution.
+#' +#' The \code{summary} function can be used to extract and print summary +#' performance results (calibration and discrimination metrics). The graphical +#' assessments of performance can be extracted using \code{plot}. #' #' @export #' @@ -109,10 +116,11 @@ #' # an example dataset within the package #' model1 <- pred_input_info(model_type = "logistic", #' model_info = SYNPM$Existing_logistic_models) -#' pred_validate(x = model1, -#' new_data = SYNPM$ValidationData, -#' binary_outcome = "Y", -#' cal_plot = FALSE) +#' val_results <- pred_validate(x = model1, +#' new_data = SYNPM$ValidationData, +#' binary_outcome = "Y", +#' cal_plot = FALSE) +#' summary(val_results) #' #' @seealso \code{\link{pred_input_info}} pred_validate <- function(x, @@ -244,6 +252,49 @@ pred_validate.predinfo_survival <- function(x, } +#' @export +print.predvalidate_logistic <- function(x, ...) { + if(x$M == 1){ + print(list("OE_ratio" = x$OE_ratio, + "OE_ratio_SE" = x$OE_ratio_SE, + "CalInt" = x$CalInt, + "CalInt_SE" = x$CalInt_SE, + "CalSlope" = x$CalSlope, + "CalSlope_SE" = x$CalSlope_SE, + "AUC" = x$AUC, + "AUC_SE" = x$AUC_SE, + "R2_CoxSnell" = x$R2_CoxSnell, + "R2_Nagelkerke" = x$R2_Nagelkerke, + "BrierScore" = x$BrierScore)) + + if(!is.null(x$PR_dist)) { + print(x$PR_dist)} + if(!is.null(x$flex_calibrationplot)) { + print(x$flex_calibrationplot)} + } else{ + for(m in 1:x$M) { + cat(paste("\nPerformance Results for Model", m, "\n", sep = " ")) + cat("================================= \n") + print(list("OE_ratio" = x[[m]]$OE_ratio, + "OE_ratio_SE" = x[[m]]$OE_ratio_SE, + "CalInt" = x[[m]]$CalInt, + "CalInt_SE" = x[[m]]$CalInt_SE, + "CalSlope" = x[[m]]$CalSlope, + "CalSlope_SE" = x[[m]]$CalSlope_SE, + "AUC" = x[[m]]$AUC, + "AUC_SE" = x[[m]]$AUC_SE, + "R2_CoxSnell" = x[[m]]$R2_CoxSnell, + "R2_Nagelkerke" = x[[m]]$R2_Nagelkerke, + "BrierScore" = x[[m]]$BrierScore)) + + if(!is.null(x[[m]]$PR_dist)) { + print(x[[m]]$PR_dist)} + if(!is.null(x[[m]]$flex_calibrationplot)) { + print(x[[m]]$flex_calibrationplot)} + } + } +} + #' @export summary.predvalidate_logistic <- function(object, ...) { @@ -261,12 +312,45 @@ summary.predvalidate_logistic <- function(object, ...) { } +#' @export +print.predvalidate_survival <- function(x, ...) { + if(x$M == 1){ + print(list("OE_ratio" = x$OE_ratio, + "OE_ratio_SE" = x$OE_ratio_SE, + "CalSlope" = x$CalSlope, + "CalSlope_SE" = x$CalSlope_SE, + "harrell_C" = x$harrell_C, + "harrell_C_SE" = x$harrell_C_SE)) + if(!is.null(x$PR_dist)) { + print(x$PR_dist)} + if(!is.null(x$flex_calibrationplot)) { + print(x$flex_calibrationplot)} + } else{ + for(m in 1:x$M) { + cat(paste("\nPerformance Results for Model", m, "\n", sep = " ")) + cat("================================= \n") + print(list("OE_ratio" = x[[m]]$OE_ratio, + "OE_ratio_SE" = x[[m]]$OE_ratio_SE, + "CalSlope" = x[[m]]$CalSlope, + "CalSlope_SE" = x[[m]]$CalSlope_SE, + "harrell_C" = x[[m]]$harrell_C, + "harrell_C_SE" = x[[m]]$harrell_C_SE)) + if(!is.null(x[[m]]$PR_dist)) { + print(x[[m]]$PR_dist)} + if(!is.null(x[[m]]$flex_calibrationplot)) { + print(x[[m]]$flex_calibrationplot)} + } + } +} + + #' @export summary.predvalidate_survival <- function(object, ...) { if(object$M == 1){ predvalidatesummary.fnc(object = object, model_type = "survival") + } else{ for(m in 1:object$M) { cat(paste("\nPerformance Results for Model", m, "\n", sep = " ")) @@ -274,10 +358,46 @@ summary.predvalidate_survival <- function(object, ...) 
{ predvalidatesummary.fnc(object = object[[m]], model_type = "survival") } + } } +#' @export +plot.predvalidate <- function(x, ...) { + + if (x$M == 1){ + if(!is.null(x$PR_dist) & !is.null(x$flex_calibrationplot)) { + print(ggpubr::ggarrange(x$PR_dist, + x$flex_calibrationplot, + nrow = 1, + ncol = 2)) + } else if(!is.null(x$PR_dist)) { + print(x$PR_dist) + } else if(!is.null(x$flex_calibrationplot)) { + print(x$flex_calibrationplot) + } else{ + cat("No plots to print; re-run pred_validate with plotting options") + } + } else { + for (m in 1:x$M) { + if(!is.null(x[[m]]$PR_dist) & !is.null(x[[m]]$flex_calibrationplot)) { + print(ggpubr::ggarrange(x[[m]]$PR_dist, + x[[m]]$flex_calibrationplot, + nrow = 1, + ncol = 2)) + } else if(!is.null(x[[m]]$PR_dist)) { + print(x[[m]]$PR_dist) + } else if(!is.null(x[[m]]$flex_calibrationplot)) { + print(x[[m]]$flex_calibrationplot) + } else{ + cat("No plots to print; re-run pred_validate with plotting options") + } + } + } + +} + predvalidatesummary.fnc <- function(object, model_type) { if(model_type == "logistic") { @@ -296,10 +416,10 @@ predvalidatesummary.fnc <- function(object, model_type) { round(object$OE_ratio_SE, 4), round((object$OE_ratio * exp(-stats::qnorm(0.975) * object$OE_ratio_SE)), 4), round((object$OE_ratio * exp(stats::qnorm(0.975) * object$OE_ratio_SE)), 4)) - results[2,] <- c(round(object$CITL, 4), - round(object$CITL_SE, 4), - round((object$CITL - (stats::qnorm(0.975)*object$CITL_SE)), 4), - round((object$CITL + (stats::qnorm(0.975)*object$CITL_SE)), 4)) + results[2,] <- c(round(object$CalInt, 4), + round(object$CalInt_SE, 4), + round((object$CalInt - (stats::qnorm(0.975)*object$CalInt_SE)), 4), + round((object$CalInt + (stats::qnorm(0.975)*object$CalInt_SE)), 4)) results[3,] <- c(round(object$CalSlope, 4), round(object$CalSlope_SE, 4), round((object$CalSlope - (stats::qnorm(0.975)*object$CalSlope_SE)), 4), @@ -326,7 +446,7 @@ predvalidatesummary.fnc <- function(object, model_type) { cat("Nagelkerke R-squared: ", round(object$R2_Nagelkerke, 4), "\n", sep = "") cat("Brier Score: ", round(object$BrierScore, 4), "\n", sep = "") - cat("\n Also examine the histogram of predicted risks. \n") + cat("\n Also examine the distribution plot of predicted risks. \n") } else if(model_type == "survival"){ @@ -337,7 +457,8 @@ predvalidatesummary.fnc <- function(object, model_type) { "Std. Err", "Lower 95% Confidence Interval", "Upper 95% Confidence Interval") - rownames(results) <- c("Observed:Expected Ratio", "Calibration Slope") + rownames(results) <- c("Observed:Expected Ratio", + "Calibration Slope") results[1,] <- c(round(object$OE_ratio, 4), round(object$OE_ratio_SE, 4), round((object$OE_ratio * exp(-stats::qnorm(0.975) * object$OE_ratio_SE)), 4), @@ -362,7 +483,7 @@ predvalidatesummary.fnc <- function(object, model_type) { round((object$harrell_C + (stats::qnorm(0.975)*object$harrell_C_SE)), 4)) print(results) - cat("\n Also examine the histogram of predicted risks. \n") + cat("\n Also examine the distribution plot of predicted risks. 
\n") } } diff --git a/R/validate_logistic.R b/R/validate_logistic.R index c71e85f..4f23208 100644 --- a/R/validate_logistic.R +++ b/R/validate_logistic.R @@ -6,7 +6,8 @@ validate_logistic <- function(ObservedOutcome, xlab = "Predicted Probability", ylab = "Observed Probability", xlim = c(0,1), - ylim = c(0,1)) { + ylim = c(0,1), + pred_rug = TRUE) { # Test for 0 and 1 probabilities n_inf <- sum(is.infinite(LP)) @@ -27,8 +28,8 @@ validate_logistic <- function(ObservedOutcome, CITL_mod <- stats::glm(ObservedOutcome ~ 1, family = stats::binomial(link = "logit"), offset = LP) - CITL <- as.numeric(stats::coef(CITL_mod)[1]) - CITLSE <- sqrt(stats::vcov(CITL_mod)[1,1]) + CalInt <- as.numeric(stats::coef(CITL_mod)[1]) + CalIntSE <- sqrt(stats::vcov(CITL_mod)[1,1]) #Estimate calibration slope @@ -64,46 +65,53 @@ validate_logistic <- function(ObservedOutcome, #Brier Score BrierScore <- 1/N * (sum((Prob - ObservedOutcome)^2)) - # If not creating a calibration plot, then at least produce histogram of - # predicted risks; otherwise this is embedded into the calibration plot - if (cal_plot == FALSE){ - plot_df <- data.frame("Prob" = Prob) - print(ggplot2::ggplot(plot_df, - ggplot2::aes(x = .data$Prob)) + - ggplot2::geom_histogram(bins = 30, - colour = "black") + - ggplot2::ggtitle("Histogram of the Probability Distribution") + - ggplot2::xlab(xlab) + - ggplot2::theme_bw(base_size = 12)) - - } else{ - + #Distribution of predicted risks: + plot_df <- data.frame("Prob" = Prob, + "Outcome" = factor(ifelse(ObservedOutcome == 1, + "Event", + "No Event"))) + PR_dist <- ggplot2::ggplot(plot_df, + ggplot2::aes(x = .data$Outcome, + y = .data$Prob)) + + ggplot2::geom_violin(position = ggplot2::position_dodge(width = .75), + linewidth = 1) + + ggplot2::geom_boxplot(width = 0.1, + outlier.shape = NA) + + ggplot2::ylab(xlab) + + ggplot2::theme_bw(base_size = 12) + + # Create flexible calibration plot: + if (cal_plot == TRUE){ if(length(unique(Prob)) <= 10) { #allows handling of intercept-only models stop("Very low unique predicted risks - calplot not possible; call again with cal_plot = FALSE") } else{ - flex_calplot(model_type = "logistic", - ObservedOutcome = ObservedOutcome, - Prob = Prob, - LP = LP, - xlim = xlim, - ylim = ylim, - xlab = xlab, - ylab = ylab) + flex_calibrationplot <- flex_calplot(model_type = "logistic", + ObservedOutcome = ObservedOutcome, + Prob = Prob, + LP = LP, + xlim = xlim, + ylim = ylim, + xlab = xlab, + ylab = ylab, + pred_rug = pred_rug) } - + } else { + flex_calibrationplot <- NULL } #Return results out <- list("OE_ratio" = OE_ratio, "OE_ratio_SE" = OE_ratio_SE, - "CITL" = CITL, - "CITL_SE" = CITLSE, + "CalInt" = CalInt, + "CalInt_SE" = CalIntSE, "CalSlope" = CalSlope, "CalSlope_SE" = CalSlopeSE, "AUC" = AUC, "AUC_SE" = AUCSE, "R2_CoxSnell" = R2_coxsnell, "R2_Nagelkerke" = R2_Nagelkerke, - "BrierScore" = BrierScore) - out + "BrierScore" = BrierScore, + "PR_dist" = PR_dist, + "flex_calibrationplot" = flex_calibrationplot) + return(out) } diff --git a/R/validate_survival.R b/R/validate_survival.R index 4c1c7e8..75d513e 100644 --- a/R/validate_survival.R +++ b/R/validate_survival.R @@ -7,7 +7,8 @@ validate_survival <- function(ObservedOutcome, xlab = "Predicted Probability", ylab = "Observed Probability", xlim = c(0,1), - ylim = c(0,1)) { + ylim = c(0,1), + pred_rug = TRUE) { # Test if max observed survival time in validation data is less than # time_horizon that performance metrics as requested for: @@ -45,6 +46,8 @@ validate_survival <- function(ObservedOutcome, OE_ratio <- NA 
OE_ratio_SE <- NA + PR_dist <- NULL + flex_calibrationplot <- NULL } else{ #Estimate calibration-in-the-large: observed-expected ratio @@ -53,34 +56,42 @@ validate_survival <- function(ObservedOutcome, OE_ratio <- (1 - KM_observed$surv) / mean(Prob) OE_ratio_SE <- sqrt(1 / KM_observed$n.event) - # If not creating a calibration plot, then at least produce histogram of - # predicted risks; otherwise this is embedded into the calibration plot - if (cal_plot == FALSE){ - plot_df <- data.frame("Prob" = Prob) - print(ggplot2::ggplot(plot_df, - ggplot2::aes(x = .data$Prob)) + - ggplot2::geom_histogram(bins = 30, - colour = "black") + - ggplot2::ggtitle("Histogram of the Probability Distribution") + - ggplot2::xlab(xlab) + - ggplot2::theme_bw(base_size = 12)) + #Distribution of predicted risks: + plot_df <- data.frame("Prob" = Prob, + "Outcome" = factor(ifelse(ObservedOutcome[,2] == 1 & + ObservedOutcome[,1] <= time_horizon, + paste("Event prior to time", + time_horizon, sep = " "), + paste("No Event/Censored prior to time", + time_horizon, sep = " ")))) + PR_dist <- ggplot2::ggplot(plot_df, + ggplot2::aes(x = .data$Outcome, + y = .data$Prob)) + + ggplot2::geom_violin(position = ggplot2::position_dodge(width = .75), + linewidth = 1) + + ggplot2::geom_boxplot(width = 0.1, + outlier.shape = NA) + + ggplot2::ylab(xlab) + + ggplot2::theme_bw(base_size = 12) - } else{ - - if(length(unique(Prob)) <= 10) { + # Create flexible calibration plot: + if (cal_plot == TRUE){ + if(length(unique(Prob)) <= 10) { #allows handling of intercept-only models stop("Very low unique predicted risks - calplot not possible; call again with cal_plot = FALSE") } else{ - flex_calplot(model_type = "survival", - ObservedOutcome = ObservedOutcome, - Prob = Prob, - LP = LP, - xlim = xlim, - ylim = ylim, - xlab = xlab, - ylab = ylab, - time_horizon = time_horizon) + flex_calibrationplot <- flex_calplot(model_type = "survival", + ObservedOutcome = ObservedOutcome, + Prob = Prob, + LP = LP, + xlim = xlim, + ylim = ylim, + xlab = xlab, + ylab = ylab, + pred_rug = pred_rug, + time_horizon = time_horizon) } - + } else { + flex_calibrationplot <- NULL } } @@ -90,6 +101,8 @@ validate_survival <- function(ObservedOutcome, "CalSlope" = CalSlope, "CalSlope_SE" = CalSlopeSE, "harrell_C" = harrell_C_est, - "harrell_C_SE" = harrell_C_SE) - out + "harrell_C_SE" = harrell_C_SE, + "PR_dist" = PR_dist, + "flex_calibrationplot" = flex_calibrationplot) + return(out) } diff --git a/man/pred_validate.Rd b/man/pred_validate.Rd index 72baa89..3128c80 100644 --- a/man/pred_validate.Rd +++ b/man/pred_validate.Rd @@ -53,12 +53,17 @@ below.} list of performance metrics, estimated by applying the existing prediction model to the new_data. An object of class "\code{predvalidate}" is a list containing relevant calibration and discrimination measures. For logistic -regression models, this will include calibration-intercept, calibration -slope, area under the ROC curve, R-squared, and Brier Score. For survival -models, this will include observed:expected ratio (if \code{cum_hazard} is -provided to \code{x}), calibration slope, and Harrell's C-statistic. -Optionally, a flexible calibration plot is also produced, along with a -histogram of the predicted risk distribution. +regression models, this will include observed:expected ratio, +calibration-intercept, calibration slope, area under the ROC curve, +R-squared, and Brier Score. 
For survival models, this will include
+observed:expected ratio (if \code{cum_hazard} is provided to \code{x}),
+calibration slope, and Harrell's C-statistic. Optionally, a flexible
+calibration plot is also produced, along with a box-plot and violin plot of
+the predicted risk distribution.
+
+The \code{summary} function can be used to extract and print summary
+performance results (calibration and discrimination metrics). The graphical
+assessments of performance can be extracted using \code{plot}.
 }
 \description{
 Validate an existing prediction model, to calculate the predictive
@@ -96,14 +101,15 @@ observed binary outcome. Various metrics of calibration (agreement between
 the observed risk and the predicted risks, across the full risk range) and
 discrimination (ability of the model to distinguish between those who
 develop the outcome and those who do not) are calculated. For calibration,
-calibration-in-the-large (CITL) and calibration slopes are estimated. CITL
-is estimated by fitting a logistic regression model to the observed binary
-outcomes, with the linear predictor of the model as an offset. For
-calibration slope, a logistic regression model is fit to the observed
-binary outcome with the linear predictor from the model as the only
-covariate. For discrimination, the function estimates the area under the
-receiver operating characteristic curve (AUC). Various other metrics are
-also calculated to assess overall accuracy (Brier score, Cox-Snell R2).
+the observed-to-expected ratio, calibration intercept and calibration
+slope are estimated. The calibration intercept is estimated by fitting a
+logistic regression model to the observed binary outcomes, with the linear
+predictor of the model as an offset. For calibration slope, a logistic
+regression model is fit to the observed binary outcome with the linear
+predictor from the model as the only covariate. For discrimination, the
+function estimates the area under the receiver operating characteristic
+curve (AUC). Various other metrics are also calculated to assess overall
+accuracy (Brier score, Cox-Snell R2).
 
 In the case of validating a survival prediction model, this function
 assesses the predictive performance of the linear predictor and
@@ -122,21 +128,23 @@ models, the cumulative baseline hazard must be available in the
 (TRUE), or not (FALSE). The calibration plot is produced by regressing the
 observed outcomes against a cubic spline of the logit of predicted risks
 (for a logistic model) or the complementary log-log of the predicted risks
-(for a survival model). A histogram of the predicted risk distribution is
-displayed on the top x-axis. Users can specify parameters to modify the
+(for a survival model). Users can specify parameters to modify the
 calibration plot. Specifically, one can specify: \code{xlab}, \code{ylab},
 \code{xlim}, and \code{ylim} to change plotting characteristics for the
-calibration plot.
+calibration plot. A rug can be added to the x-axis of the plot by setting
+\code{pred_rug} to TRUE; this can be used to show the distribution of the
+predicted risks.
} \examples{ #Example 1 - multiple existing model, with outcome specified; uses # an example dataset within the package model1 <- pred_input_info(model_type = "logistic", model_info = SYNPM$Existing_logistic_models) -pred_validate(x = model1, - new_data = SYNPM$ValidationData, - binary_outcome = "Y", - cal_plot = FALSE) +val_results <- pred_validate(x = model1, + new_data = SYNPM$ValidationData, + binary_outcome = "Y", + cal_plot = FALSE) +summary(val_results) } \seealso{ diff --git a/tests/testthat/_snaps/pred_validate_logistic.md b/tests/testthat/_snaps/pred_validate_logistic.md index a4839a4..99e0c2b 100644 --- a/tests/testthat/_snaps/pred_validate_logistic.md +++ b/tests/testthat/_snaps/pred_validate_logistic.md @@ -30,5 +30,5 @@ Nagelkerke R-squared: -0.0863 Brier Score: 0.1249 - Also examine the histogram of predicted risks. + Also examine the distribution plot of predicted risks. diff --git a/tests/testthat/_snaps/pred_validate_survival.md b/tests/testthat/_snaps/pred_validate_survival.md index 8e632a1..e7e07b0 100644 --- a/tests/testthat/_snaps/pred_validate_survival.md +++ b/tests/testthat/_snaps/pred_validate_survival.md @@ -21,5 +21,5 @@ Upper 95% Confidence Interval Harrell C 0.5932 - Also examine the histogram of predicted risks. + Also examine the distribution plot of predicted risks. diff --git a/tests/testthat/test-pred_validate_logistic.R b/tests/testthat/test-pred_validate_logistic.R index c142616..18eae05 100644 --- a/tests/testthat/test-pred_validate_logistic.R +++ b/tests/testthat/test-pred_validate_logistic.R @@ -15,9 +15,10 @@ test_that("output of pred_validate is as expected - single models", { expect_s3_class(val_results, c("predvalidate_logistic", "predvalidate")) expect_type(val_results, type = "list") expect_equal(names(val_results), - c("OE_ratio", "OE_ratio_SE", "CITL", "CITL_SE", "CalSlope", + c("OE_ratio", "OE_ratio_SE", "CalInt", "CalInt_SE", "CalSlope", "CalSlope_SE", "AUC", "AUC_SE", "R2_CoxSnell", - "R2_Nagelkerke", "BrierScore", "M")) + "R2_Nagelkerke", "BrierScore", + "PR_dist", "flex_calibrationplot", "M")) expect_snapshot(summary(val_results)) @@ -49,8 +50,9 @@ test_that("output of pred_validate is as expected - multiple models", { for(m in 1:model2$M) { expect_type(val_results[[m]], type = "list") expect_equal(names(val_results[[m]]), - c("OE_ratio", "OE_ratio_SE", "CITL", "CITL_SE", "CalSlope", + c("OE_ratio", "OE_ratio_SE", "CalInt", "CalInt_SE", "CalSlope", "CalSlope_SE", "AUC", "AUC_SE", "R2_CoxSnell", - "R2_Nagelkerke", "BrierScore")) + "R2_Nagelkerke", "BrierScore", + "PR_dist", "flex_calibrationplot")) } }) diff --git a/tests/testthat/test-pred_validate_survival.R b/tests/testthat/test-pred_validate_survival.R index b89558f..df49b5d 100644 --- a/tests/testthat/test-pred_validate_survival.R +++ b/tests/testthat/test-pred_validate_survival.R @@ -12,7 +12,7 @@ test_that("output of pred_validate is as expected - single models", { expect_type(val_results, type = "list") expect_equal(names(val_results), c("OE_ratio", "OE_ratio_SE", "CalSlope", "CalSlope_SE", "harrell_C", - "harrell_C_SE", "M")) + "harrell_C_SE", "PR_dist", "flex_calibrationplot", "M")) expect_snapshot(summary(val_results)) }) @@ -40,6 +40,6 @@ test_that("output of pred_validate is as expected - multiple models", { expect_type(val_results[[m]], type = "list") expect_equal(names(val_results[[m]]), c("OE_ratio", "OE_ratio_SE", "CalSlope", "CalSlope_SE", "harrell_C", - "harrell_C_SE")) + "harrell_C_SE", "PR_dist", "flex_calibrationplot")) } }) diff --git a/vignettes/predRupdate.Rmd 
b/vignettes/predRupdate.Rmd
index 9ccfc4b..c6ba0e4 100644
--- a/vignettes/predRupdate.Rmd
+++ b/vignettes/predRupdate.Rmd
@@ -64,7 +64,18 @@ validation_results <- pred_validate(x = Existing_Logistic_Model,
 summary(validation_results) #use summary() to obtain a tidy output summary of the model performance
 ```
-This produces a flexible calibration plot, along with outputting various metrics of model calibration (e.g., calibration intercept and slope), discrimination (e.g., area under the ROC curve) and overall performance (e.g., R-squared). We can see that this model has poor calibration (calibration plot deviating from the y=x line, with calibration intercept and slope significantly different from 0 and 1, respectively), and poor discrimination. One may wish to update this model to the new dataset - see Example 2 below.
+This outputs various metrics of model calibration (e.g., calibration intercept and slope), discrimination (e.g., area under the ROC curve) and overall performance (e.g., R-squared). We can see that this model has poor calibration (calibration intercept and slope significantly different from 0 and 1, respectively), and poor discrimination. We can also obtain the flexible calibration plot, as follows:
+
+```{r, fig.height=6, fig.width=6}
+validation_results$flex_calibrationplot
+```
+
+A rug of the predicted risks can be added along the x-axis of the calibration plot via the `pred_rug` argument of `pred_validate()`. A box-plot and violin plot of the predicted risk distribution, stratified by observed outcome, is also returned (as `validation_results$PR_dist`), and `plot(validation_results)` displays the two plots side-by-side. The package returns these plots as ggplot2 objects, so further modification can be made using ggplot2 statements. For example, one can change the theme of the plot as follows:
+```{r, fig.height=6, fig.width=6}
+validation_results$flex_calibrationplot + ggplot2::theme_classic()
+```
+
+One may wish to update this model to the new dataset - see Example 2 below.
 
 ## Survival analysis model
 The above example considered the validation of an existing CPM that was based on logistic regression. __predRupdate__ also contains functionality to validate CPMs that are based on time-to-event (survival) models (e.g. a Cox proportional hazards model). In such a case, the baseline cumulative hazard of the model should also be specified, along with the regression coefficients.
@@ -107,38 +118,16 @@ validation_results <- pred_validate(x = Existing_TTE_Model,
 summary(validation_results)
 ```
 
-Here, we see that the existing model under-predicts the mean risk at 5-months (observed:Expected Ratio greater than 1), and there is evidence of over-fitting (both from the calibration plot and the calibration slope being significantly different from 1). The model also has poor discrimination (Harrell C).
+Here, we see that the existing model under-predicts the mean risk at 5 months (observed:expected ratio greater than 1), and there is evidence of over-fitting (calibration slope significantly different from 1). The model also has poor discrimination (a low Harrell C-statistic). We can also see this in the plot of the predicted risk distribution and the calibration plot, as follows:
+```{r, fig.height=6, fig.width=10}
+plot(validation_results)
+```
 
 ### Specifying the baseline cumulative hazard
 When validating an existing time-to-event model, the baseline cumulative hazard of the existing CPM should be reported (e.g., from the original model publication). In some cases, this might be reported at discrete follow-up times (like in the above example). Alternatively, the entire baseline cumulative hazard curve may be presented, or indeed a parametric form of the baseline cumulative hazard may be provided. In those situations, one should extract (from the plot) or calculate (from the parametric form) the baseline cumulative hazard at multiple follow-up times of interest (i.e., the follow-up times against which one wishes to validate the model); a hypothetical sketch of the parametric case is given below.
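To illustrate the parametric case, the following sketch tabulates the baseline cumulative hazard at several follow-up times from an assumed Weibull form, H0(t) = lambda * t^gamma. The lambda and gamma values are hypothetical (not taken from the SYNPM models), and the two-column layout is assumed to match the `cum_hazard` data.frame expected by `pred_input_info()`:

```{r}
#Hypothetical Weibull baseline cumulative hazard, H0(t) = lambda * t^gamma
#(lambda and gamma are illustrative values only)
lambda <- 0.05
gamma <- 1.2
follow_up_times <- 1:5 #the follow-up times at which validation is planned

#tabulate the cumulative baseline hazard at each follow-up time of interest;
#this two-column layout is assumed to match the cum_hazard argument of
#pred_input_info()
baseline_cum_hazard <- data.frame("time" = follow_up_times,
                                  "hazard" = lambda * (follow_up_times^gamma))
baseline_cum_hazard
```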
However, in some cases, the baseline cumulative hazard of the existing time-to-event CPM may not be reported. In such a situation, one can still use __predRupdate__ to validate the model, but only a limited number of metrics will be produced (i.e., only those metrics that require the linear predictor, rather than absolute risk predictions at a given follow-up time). Specifically, the observed:expected ratio and the calibration plot will not be produced if the baseline cumulative hazard is not provided.
-For example, suppose the baseline cumulative hazard for the above model was not available. We could validate this model using __predRupdate__ as follows:
-```{r}
-# create a data.frame of the model coefficients, with columns being variables
-coefs_table <- data.frame("Age" = 0.007,
-                          "SexM" = 0.225,
-                          "Smoking_Status" = 0.685,
-                          "Diabetes" = 0.425,
-                          "Creatine" = 0.587)
-
-#pass this into pred_input_info()
-Existing_TTE_Model <- pred_input_info(model_type = "survival",
-                                      model_info = coefs_table,
-                                      cum_hazard = NULL) #leave as NULL if the baseline not available
-
-#now validate against the time-to-event outcomes in the new dataset:
-validation_results <- pred_validate(x = Existing_TTE_Model,
-                                    new_data = SYNPM$ValidationData,
-                                    survival_time = "ETime",
-                                    event_indicator = "Status",
-                                    time_horizon = 5)
-summary(validation_results)
-```
-
-Here, we see that no calibration plot is produced, and the observed:expected ratio is NA. A warning message is given to highlight this.
-
 # Example 2: model updating on new data
 In the validation of an existing logistic regression model in Example 1 above, we found that the existing model was miscalibrated in the new data. One strategy to handle this is to apply a range of model updating methods; see `vignette("predRupdate_technical")` for a technical discussion of these methods.
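As a forward pointer, a minimal sketch of what such an updating step might look like is given below. It assumes that `pred_update()` accepts a `pred_input_info` object together with an `update_type` argument (here, "recalibration") and the same `new_data`/`binary_outcome` arguments as `pred_validate()`; treat this as illustrative only, and consult the `pred_update()` documentation and `vignette("predRupdate_technical")` for the definitive interface:

```{r, eval=FALSE}
#illustrative only: recalibrate the existing model on the new data, assuming
#pred_update() takes the same new_data/binary_outcome arguments as
#pred_validate()
updated_model <- pred_update(x = Existing_Logistic_Model,
                             update_type = "recalibration",
                             new_data = SYNPM$ValidationData,
                             binary_outcome = "Y")
summary(updated_model)
```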