Skip to content

Commit

Permalink
Merge pull request #15 from ModelOriented/issue#11
Browse files Browse the repository at this point in the history
update subtitle in plot_
  • Loading branch information
krzyzinskim authored Oct 27, 2022
2 parents 561b62e + 402b83a commit e107cd2
Show file tree
Hide file tree
Showing 21 changed files with 84 additions and 60 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* improved and unified API documentation ([#2](https://github.com/ModelOriented/survex/issues/2))
* added references to used methods ([#5](https://github.com/ModelOriented/survex/issues/5))
* changed the package used to draw complex plots from `gridExtra` to `patchwork` ([#7](https://github.com/ModelOriented/survex/pull/7))
* ...
* fixed subtitles in plots ([#11](https://github.com/ModelOriented/survex/issues/11))

# survex 0.1.1
* The `survex` package is now public
Expand Down
4 changes: 2 additions & 2 deletions R/plot_model_performance_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#' * `metrics` - character, names of metrics to be plotted (subset of C/D AUC", "Brier score" for `metrics_type %in% c("time_dependent", "functional")` or subset of "C-index","Integrated Brier score", "Integrated C/D AUC" for `metrics_type == "scalar"`), by default (`NULL`) all metrics of a given type are plotted
#' * `metrics_type` - character, either one of `c("time_dependent","functional")` for functional metrics or `"scalar"` for scalar metrics
#' * `title` - character, title of the plot
#' * `subtitle` - character, subtitle of the plot, if `NULL` automaticaly generated as "created for XXX, YYY models", where XXX and YYY are explainer labels
#' * `subtitle` - character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
#' * `facet_ncol` - number of columns for arranging subplots
#' * `colors` - character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
#'
Expand All @@ -26,7 +26,7 @@
#' * `x` - an object of class `"surv_model_performance_rocs"` to be plotted
#' * `...` - additional objects of class `"surv_model_performance_rocs"` to be plotted together
#' * `title` - character, title of the plot
#' * `subtitle` - character, subtitle of the plot, if `NULL` automaticaly generated as "created for XXX, YYY models", where XXX and YYY are explainer labels
#' * `subtitle` - character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
#' * `colors` - character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
#' * `facet_ncol` - number of columns for arranging subplots
#'
Expand Down
8 changes: 4 additions & 4 deletions R/plot_model_profile_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#' @param facet_ncol number of columns for arranging subplots
#' @param numerical_plot_type character, either `"lines"`, or `"contours"` selects the type of numerical variable plots
#' @param title character, title of the plot
#' @param subtitle character, subtitle of the plot, if `NULL` automaticaly generated as "created for XXX, YYY models", where XXX and YYY are explainer labels
#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
#'
#' @return A grid of `ggplot` objects arranged with the `gridExtra::grid.arrange` function.
Expand Down Expand Up @@ -40,7 +40,7 @@ plot.model_profile_survival <- function(x,
facet_ncol = NULL,
numerical_plot_type = "lines",
title = "Partial dependence survival profile",
subtitle = NULL,
subtitle = "default",
colors = NULL)
{

Expand All @@ -61,7 +61,7 @@ plot.model_profile_survival <- function(x,
aggregated_profiles[aggregated_profiles$`_vname_` %in% all_variables, ]


if (is.null(subtitle)) {
if (!is.null(subtitle) && subtitle == "default") {
labels <-
paste0(unique(aggregated_profiles$`_label_`), collapse = ", ")
subtitle <- paste0("created for the ", labels, " model")
Expand All @@ -83,7 +83,7 @@ plot.model_profile_survival <- function(x,

aggregated_profiles$`_real_point_` <- FALSE

pl <- plot_individual_ceteris_paribus_survival(aggregated_profiles, variables, facet_ncol, colors, numerical_plot_type, title, subtitle)
pl <- plot_individual_ceteris_paribus_survival(aggregated_profiles, variables, facet_ncol, colors, numerical_plot_type, title)

patchwork::wrap_plots(pl, ncol = facet_ncol) +
patchwork::plot_annotation(title = title,
Expand Down
4 changes: 2 additions & 2 deletions R/plot_predict_parts_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#' * `x` - an object of class `"surv_shap"` to be plotted
#' * `...` - additional objects of class `surv_shap` to be plotted together
#' * `title` - character, title of the plot
#' * `subtitle` - character, subtitle of the plot, if `NULL` automaticaly generated as "created for XXX, YYY models", where XXX and YYY are explainer labels
#' * `subtitle` - character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
#' * `colors` - character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
#'
#' ## `plot.surv_lime`
Expand All @@ -25,7 +25,7 @@
#' * `show_survival_function` - logical, if the survival function of the explanations should be plotted next to the barplot
#' * `...` - other parameters currently ignored
#' * `title` - character, title of the plot
#' * `subtitle` - character, subtitle of the plot, if `NULL` automaticaly generated as "created for XXX, YYY models", where XXX and YYY are explainer labels
#' * `subtitle` - character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
#' * `colors` - character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
#'
#' @family functions for plotting 'predict_parts_survival' objects
Expand Down
2 changes: 1 addition & 1 deletion R/plot_predict_profile_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#' * `variables` - character, names of the variables to be plotted
#' * `numerical_plot_type` - character, either `"lines"`, or `"contours"` selects the type of numerical variable plots
#' * `title` - character, title of the plot
#' * `subtitle` - character, subtitle of the plot, if `NULL` automaticaly generated as "created for XXX, YYY models", where XXX and YYY are explainer labels
#' * `subtitle` - character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
#'
#'
#' @family functions for plotting 'predict_profile_survival' objects
Expand Down
18 changes: 8 additions & 10 deletions R/plot_surv_ceteris_paribus.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#' @param variables character, names of the variables to be plotted
#' @param numerical_plot_type character, either `"lines"`, or `"contours"` selects the type of numerical variable plots
#' @param title character, title of the plot
#' @param subtitle character, subtitle of the plot, if `NULL` automaticaly generated as "created for XXX, YYY models", where XXX and YYY are explainer labels
#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
#'
#' @return A grid of `ggplot` objects arranged with the `gridExtra::grid.arrange` function.
#'
Expand Down Expand Up @@ -47,7 +47,7 @@ plot.surv_ceteris_paribus <- function(x,
variables = NULL,
numerical_plot_type = "lines",
title = "Ceteris paribus survival profile",
subtitle = NULL) {
subtitle = "default") {
if (!is.null(variable_type))
check_variable_type(variable_type)

Expand All @@ -62,7 +62,7 @@ plot.surv_ceteris_paribus <- function(x,
all_profiles$`_ids_` <- factor(all_profiles$`_ids_`)

# extract labels to use in the default subtitle
if (is.null(subtitle)) {
if (!is.null(subtitle) && subtitle == "default") {
labels <- paste0(unique(all_profiles$`_label_`), collapse = ", ")
subtitle <- paste0("created for the ", labels, " model")
}
Expand Down Expand Up @@ -118,8 +118,7 @@ plot.surv_ceteris_paribus <- function(x,
facet_ncol = facet_ncol,
colors = colors,
numerical_plot_type = numerical_plot_type,
title = title,
subtitle = subtitle
title = title
)

patchwork::wrap_plots(pl, ncol = facet_ncol) +
Expand All @@ -135,8 +134,7 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles,
facet_ncol,
colors,
numerical_plot_type,
title,
subtitle) {
title) {
pl <- lapply(variables, function(var) {
df <- all_profiles[all_profiles$`_vname_` == var, ]

Expand Down Expand Up @@ -210,11 +208,11 @@ plot_individual_ceteris_paribus_survival <- function(all_profiles,
size = 0.8) +
geom_line(data = df[df$`_real_point_`, ],
size = 0.8, linetype = "longdash") +
scale_color_manual(name = "",
scale_color_manual(name = paste0(unique(df$`_vname_`), " value"),
values = generate_discrete_color_scale(n_colors, colors)) +
facet_wrap(~`_vname_`, scales = "free_x", ncol = facet_ncol) +
theme_drwhy() +
xlab("") + ylab("survival function value") + ylim(c(0, 1))
xlab("") + ylab("survival function value") + ylim(c(0, 1)) +
facet_wrap(~`_vname_`, ncol = facet_ncol)
}
})
}
Expand Down
12 changes: 6 additions & 6 deletions R/plot_surv_feature_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#' @param x an object of class `"surv_feature_importance"` to be plotted
#' @param ... additional objects of class `"surv_feature_importance"` to be plotted together
#' @param title character, title of the plot
#' @param subtitle character, subtitle of the plot, if `NULL` automaticaly generated as "created for XXX, YYY models", where XXX and YYY are explainer labels
#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
#' @param max_vars maximum number of variables to be plotted (least important variables are ignored)
#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
#'
Expand All @@ -33,10 +33,10 @@
#'
#' @export
plot.surv_feature_importance <- function(x, ...,
title = "Time-dependent feature importance",
subtitle = NULL,
max_vars = 6,
colors = NULL) {
title = "Time-dependent feature importance",
subtitle = "default",
max_vars = 6,
colors = NULL) {

df_list <- c(list(x), list(...))

Expand Down Expand Up @@ -70,7 +70,7 @@ plot.surv_feature_importance <- function(x, ...,
y_lab <- paste0("Loss function after variable's permutations", additional_info)
}

if (is.null(subtitle)) {
if (!is.null(subtitle) && subtitle == "default") {
glm_labels <- paste0(label, collapse = ", ")
subtitle <- paste0("created for the ", glm_labels, " model")
}
Expand Down
14 changes: 10 additions & 4 deletions R/plot_surv_lime.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#' @param show_survival_function logical, if the survival function of the explanations should be plotted next to the barplot
#' @param ... other parameters currently ignored
#' @param title character, title of the plot
#' @param subtitle character, subtitle of the plot, if `NULL` automaticaly generated as "created for XXX, YYY models", where XXX and YYY are explainer labels
#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
#'
#' @return An object of the class `ggplot`.
Expand All @@ -26,7 +26,13 @@
#' plot(p_parts_lime)
#'
#' @export
plot.surv_lime <- function(x, type = "local_importance", show_survival_function = TRUE, ..., title = "SurvLIME", subtitle = NULL, colors = NULL) {
plot.surv_lime <- function(x,
type = "local_importance",
show_survival_function = TRUE,
...,
title = "SurvLIME",
subtitle = "default",
colors = NULL) {
if (!type %in% c("coefficients", "local_importance"))
stop("Type should be one of `coefficients`, `local_importance`")

Expand All @@ -40,11 +46,11 @@ plot.surv_lime <- function(x, type = "local_importance", show_survival_function
sign_local_importance = as.factor(sign(x$beta * x$variable_values)),
local_importance = x$beta * x$variable_values)

if (is.null(subtitle))
if (!is.null(subtitle) && subtitle == "default") {
subtitle <- paste0("created for the ", attr(x, "label"), " model")
}

if (type == "coefficients") {

x_lab <- "SurvLIME coefficients"
y_lab <- ""
pl <- with(df, {
Expand Down
19 changes: 14 additions & 5 deletions R/plot_surv_model_performance.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#' @param metrics character, names of metrics to be plotted (subset of C/D AUC", "Brier score" for `metrics_type %in% c("time_dependent", "functional")` or subset of "C-index","Integrated Brier score", "Integrated C/D AUC" for `metrics_type == "scalar"`), by default (`NULL`) all metrics of a given type are plotted
#' @param metrics_type character, either one of `c("time_dependent","functional")` for functional metrics or `"scalar"` for scalar metrics
#' @param title character, title of the plot
#' @param subtitle character, subtitle of the plot, if `NULL` automatically generated as "created for XXX, YYY models", where XXX and YYY are explainer labels
#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
#' @param facet_ncol number of columns for arranging subplots
#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
#'
Expand All @@ -27,7 +27,14 @@
#' plot(m_perf)
#'
#' @export
plot.surv_model_performance <- function(x, ..., metrics = NULL, metrics_type = "time_dependent", title = "Model performance", subtitle = NULL, facet_ncol = NULL, colors = NULL) {
plot.surv_model_performance <- function(x,
...,
metrics = NULL,
metrics_type = "time_dependent",
title = "Model performance",
subtitle = "default",
facet_ncol = NULL,
colors = NULL) {


if (metrics_type %in% c("time_dependent", "functional")) {
Expand All @@ -45,12 +52,13 @@ plot.surv_model_performance <- function(x, ..., metrics = NULL, metrics_type = "


#' @importFrom DALEX theme_drwhy
plot_td_surv_model_performance <- function(x, ..., metrics = NULL, title = NULL, subtitle = NULL, facet_ncol = NULL, colors = NULL) {
plot_td_surv_model_performance <- function(x, ..., metrics = NULL, title = NULL, subtitle = "default", facet_ncol = NULL, colors = NULL) {

df <- concatenate_td_dfs(x, ...)

if (is.null(subtitle))
if (!is.null(subtitle) && subtitle == "default") {
subtitle <- paste0("created for the ", paste(unique(df$label), collapse = ", "), " model")
}

if (is.null(metrics)) metrics <- c("C/D AUC", "Brier score")

Expand All @@ -71,8 +79,9 @@ plot_td_surv_model_performance <- function(x, ..., metrics = NULL, title = NULL,
plot_scalar_surv_model_performance <- function(x, ..., metrics = NULL, title = NULL, subtitle = NULL, facet_ncol = NULL, colors = NULL) {
df <- concatenate_dfs(x, ...)

if (is.null(subtitle))
if (!is.null(subtitle) && subtitle == "default") {
subtitle <- paste0("created for the ", paste(unique(df$label), collapse = ", "), " model")
}

if (is.null(metrics)) metrics <- c("C-index", "Integrated Brier score", "Integrated C/D AUC")

Expand Down
12 changes: 9 additions & 3 deletions R/plot_surv_model_performance_rocs.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#' @param x an object of class `"surv_model_performance_rocs"` to be plotted
#' @param ... additional objects of class `"surv_model_performance_rocs"` to be plotted together
#' @param title character, title of the plot
#' @param subtitle character, subtitle of the plot, if `NULL` automaticaly generated as "created for XXX, YYY models", where XXX and YYY are explainer labels
#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
#' @param facet_ncol number of columns for arranging subplots
#'
Expand All @@ -25,7 +25,12 @@
#' plot(m_perf_roc)
#'
#' @export
plot.surv_model_performance_rocs <- function(x, ..., title = "ROC curves for selected timepoints", subtitle = NULL, colors = NULL, facet_ncol = NULL) {
plot.surv_model_performance_rocs <- function(x,
...,
title = "ROC curves for selected timepoints",
subtitle = "default",
colors = NULL,
facet_ncol = NULL) {

dfl <- c(list(x), list(...))

Expand All @@ -39,8 +44,9 @@ plot.surv_model_performance_rocs <- function(x, ..., title = "ROC curves for sel

df$time_formatted <- paste0("t=", df$time)

if (is.null(subtitle))
if (!is.null(subtitle) && subtitle == "default") {
subtitle <- paste0("created for the ", paste(unique(df$label), collapse = ", "), " model")
}

num_colors <- length(unique(df$label))

Expand Down
11 changes: 8 additions & 3 deletions R/plot_surv_shap.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#' @param x an object of class `"surv_shap"` to be plotted
#' @param ... additional objects of class `surv_shap` to be plotted together
#' @param title character, title of the plot
#' @param subtitle character, subtitle of the plot, if `NULL` automatically generated as "created for XXX, YYY models", where XXX and YYY are explainer labels
#' @param subtitle character, subtitle of the plot, `'default'` automatically generates "created for XXX, YYY models", where XXX and YYY are the explainer labels
#' @param colors character vector containing the colors to be used for plotting variables (containing either hex codes "#FF69B4", or names "blue")
#'
#' @return An object of the class `ggplot`.
Expand All @@ -26,7 +26,11 @@
#' }
#'
#' @export
plot.surv_shap <- function(x, ..., title = "SurvSHAP(t)", subtitle = NULL, colors = NULL) {
plot.surv_shap <- function(x,
...,
title = "SurvSHAP(t)",
subtitle = "default",
colors = NULL) {

dfl <- c(list(x), list(...))

Expand All @@ -46,8 +50,9 @@ plot.surv_shap <- function(x, ..., title = "SurvSHAP(t)", subtitle = NULL, color
long_df <- do.call(rbind, long_df)
label <- unique(long_df$label)

if (is.null(subtitle))
if (!is.null(subtitle) && subtitle == "default") {
subtitle <- paste0("created for the ", paste(label, collapse = ", "), " model")
}

n_colors <- length(unique(long_df$ind))

Expand Down
Loading

0 comments on commit e107cd2

Please sign in to comment.