-
-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
456 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,273 @@ | ||
#' @title Update the observations of a modelStudio object
#'
#' @description
#' This function calculates local explanations on new observations and adds them
#' to a \code{modelStudio} object.
#'
#' @param object A \code{modelStudio} created with \code{modelStudio()}.
#' @param explainer An \code{explainer} created with \code{DALEX::explain()}.
#' @param new_observation New observations with columns that correspond to variables used in the model.
#'  If \code{NULL}, three observations are sampled at random from the explainer data.
#' @param new_observation_y True label for \code{new_observation} (optional).
#' @param max_features Maximum number of features to be included in BD and SV plots.
#'  Default is \code{10}.
#' @param B Number of permutation rounds used for calculation of SV and FI.
#'  Default is \code{10}.
#'  See \href{https://modelstudio.drwhy.ai/articles/ms-perks-features.html#more-calculations-means-more-time}{\bold{vignette}}.
#' @param show_info Print progress information to the console. Default is \code{TRUE}.
#' @param parallel Speed up the computation using \code{parallelMap::parallelMap()}.
#'  See \href{https://modeloriented.github.io/modelStudio/articles/ms-perks-features.html#parallel-computation}{\bold{vignette}}.
#'  This might interfere with showing progress using \code{show_info}.
#' @param overwrite Overwrite existing observations and their explanations.
#'  Default is \code{FALSE} which means add new observations to the existing ones.
#' @param ... Other parameters.
#'
#' @return An object of the \code{r2d3, htmlwidget, modelStudio} class.
#'
#' @references
#'
#' \itemize{
#'   \item The input object is implemented in \href{https://modeloriented.github.io/DALEX/}{\bold{DALEX}}
#'   \item Feature Importance, Ceteris Paribus, Partial Dependence and Accumulated Dependence plots
#'    are implemented in \href{https://modeloriented.github.io/ingredients/}{\bold{ingredients}}
#'   \item Break Down and Shapley Values plots are implemented in \href{https://modeloriented.github.io/iBreakDown/}{\bold{iBreakDown}}
#' }
#'
#' @seealso
#' Vignettes: \href{https://modeloriented.github.io/modelStudio/articles/ms-r-python-examples.html}{\bold{modelStudio - R & Python examples}}
#' and \href{https://modeloriented.github.io/modelStudio/articles/ms-perks-features.html}{\bold{modelStudio - perks and features}}
#'
#' @examples
#' library("DALEX")
#' library("modelStudio")
#'
#' # fit a model
#' model_apartments <- glm(m2.price ~. , data = apartments)
#'
#' # create an explainer for the model
#' explainer_apartments <- explain(model_apartments,
#'                                 data = apartments,
#'                                 y = apartments$m2.price)
#'
#' # make a studio for the model
#' ms <- modelStudio(explainer_apartments)
#' ms
#'
#' # add new observations
#' ms <- ms_update_observations(ms,
#'                              explainer_apartments,
#'                              new_observation = apartments[100:101,],
#'                              new_observation_y = apartments$m2.price[100:101])
#' ms
#'
#' # overwrite the observations with new ones
#' ms <- ms_update_observations(ms,
#'                              explainer_apartments,
#'                              new_observation = apartments[100:101,],
#'                              overwrite = TRUE)
#' ms
#'
#' @export
#' @rdname ms_update_observations
ms_update_observations <- function(object,
                                   explainer,
                                   new_observation = NULL,
                                   new_observation_y = NULL,
                                   max_features = 10,
                                   B = 10,
                                   show_info = TRUE,
                                   parallel = FALSE,
                                   overwrite = FALSE,
                                   ...) {

  stopifnot(inherits(object, "modelStudio"))
  stopifnot(inherits(explainer, "explainer"))

  ## named `ms_options` so that it does not shadow base::options(), which is
  ## called as a function further down
  ms_options <- object$x$options
  model <- explainer$model
  data <- explainer$data
  y <- explainer$y
  predict_function <- explainer$predict_function
  label <- explainer$label

  if (is.null(rownames(data))) {
    rownames(data) <- seq_len(nrow(data))
  }

  if (is.null(new_observation)) {
    if (show_info) message("`new_observation` argument is NULL.\n",
                           "Observations needed to calculate local explanations are taken at random from the data.\n")
    new_observation <- ingredients::select_sample(data, 3)

  } else if (is.null(dim(new_observation))) {
    warning("`new_observation` argument is not a data.frame nor a matrix, coerced to data.frame\n")
    new_observation <- as.data.frame(new_observation)

  } else if (is.null(rownames(new_observation))) {
    rownames(new_observation) <- seq_len(nrow(new_observation))
  }

  ## fail early: with zero rows there is nothing to explain, and indexing the
  ## first row below would silently produce an all-NA observation
  if (nrow(new_observation) == 0) {
    stop("`new_observation` has no rows, there is nothing to explain\n")
  }

  check_single_prediction <- try(predict_function(model, new_observation[1, , drop = FALSE]), silent = TRUE)
  if (inherits(check_single_prediction, "try-error")) {
    stop("`predict_function` returns an error when executed on `new_observation[1,, drop = FALSE]` \n")
  }

  ## get proper names of features that arent target
  is_y <- is_y_in_data(data, y)
  potential_variable_names <- names(is_y[!is_y])
  variable_names <- intersect(potential_variable_names, colnames(new_observation))
  ## get rid of target in data; `drop = FALSE` keeps a two-dimensional object
  ## even when only a single feature column is left
  data <- data[, !is_y, drop = FALSE]

  obs_count <- nrow(new_observation)
  obs_data <- new_observation

  ## later update progress bar after all explanation functions;
  ## `pb` stays NULL when no progress is shown so that it can always be passed on
  pb <- NULL
  if (show_info) {
    pb <- progress_bar$new(
      format = "  Calculating :what \n    Elapsed time: :elapsedfull ETA::eta", # :percent  [:bar]
      total = (3*B + 2 + 1)*obs_count,
      show_after = 0
    )
    pb$tick(0, tokens = list(what = "..."))
  }

  ## computes the Break Down, Shapley Values and Ceteris Paribus data for the
  ## i-th row of `obs_data`; shared by the parallel and the sequential path
  explain_observation <- function(i, model, data, predict_function, label, B, show_boxplot, ...) {
    new_observation <- obs_data[i, , drop = FALSE]

    bd <- calculate(
      iBreakDown::local_attributions(
        model, data, predict_function, new_observation, label = label),
      paste0("iBreakDown::local_attributions (", i, ")"), show_info, pb, 2)
    sv <- calculate(
      iBreakDown::shap(
        model, data, predict_function, new_observation, label = label, B = B),
      paste0("iBreakDown::shap (", i, ")"), show_info, pb, 3*B)
    cp <- calculate(
      ingredients::ceteris_paribus(
        model, data, predict_function, new_observation, label = label),
      paste0("ingredients::ceteris_paribus (", i, ")"), show_info, pb, 1)

    bd_data <- prepare_break_down(bd, max_features, ...)
    sv_data <- prepare_shapley_values(sv, max_features, show_boxplot, ...)
    cp_data <- prepare_ceteris_paribus(cp, variables = variable_names)

    list(bd_data, cp_data, sv_data)
  }

  if (parallel) {
    parallelMap::parallelStart()
    ## make sure the parallel backend is stopped even if an explanation fails
    on.exit(parallelMap::parallelStop(), add = TRUE)
    parallelMap::parallelLibrary(packages = loadedNamespaces())

    obs_list <- parallelMap::parallelMap(explain_observation, seq_len(obs_count),
                                         more.args = list(
                                           model = model,
                                           data = data,
                                           predict_function = predict_function,
                                           label = label,
                                           B = B,
                                           show_boxplot = ms_options$show_boxplot,
                                           ...
                                         ))

  } else {
    ## count once per observation
    obs_list <- vector("list", obs_count)
    for (i in seq_len(obs_count)) {
      obs_list[[i]] <- explain_observation(i, model, data, predict_function, label, B,
                                           ms_options$show_boxplot, ...)
    }
  }

  names(obs_list) <- rownames(obs_data)

  #:# prepare new observation data for drop down
  between <- " - "
  if (is.null(new_observation_y)) new_observation_y <- between <- ""
  drop_down_data <- data.frame(
    id = rownames(obs_data),
    text = paste0(rownames(obs_data), between, new_observation_y),
    stringsAsFactors = FALSE
  )

  #:# extract old data
  old_data <- jsonlite::fromJSON(object$x$data, simplifyVector = FALSE)

  if (!overwrite) {
    #:# extract old drop down and merge with new one
    old_drop_down_data <- jsonlite::fromJSON(ms_options$drop_down_data,
                                             simplifyVector = FALSE, simplifyDataFrame = TRUE)
    drop_down_data <- rbind(old_drop_down_data, drop_down_data)

    #:# update new data
    obs_list <- c(old_data[[1]], obs_list)
    if (length(unique(names(obs_list))) != length(obs_list)) {
      warning("new_observation ids overlap with existing data, using unique ids")
      ## indexing by name keeps the first occurrence of each id, i.e. the
      ## already existing explanation; the drop down is filtered the same way
      obs_list <- obs_list[unique(names(obs_list))]
      drop_down_data <- drop_down_data[!duplicated(drop_down_data$id), ]
    }
  }

  #:# input new data
  temp <- jsonlite::toJSON(list(obs_list, old_data[[2]], old_data[[3]],
                                old_data[[4]], old_data[[5]], old_data[[6]]),
                           auto_unbox = TRUE)
  widget_id <- paste0("widget-", digest::digest(temp))

  #:# extract old options and update them
  new_options <- ms_options
  new_options$widget_id <- widget_id
  new_options$variable_names <- variable_names
  new_options$footer_text <- paste0("Site built with modelStudio v", utils::packageVersion("modelStudio"),
                                    " on ", format(Sys.time(), usetz = FALSE))
  new_options$drop_down_data <- jsonlite::toJSON(drop_down_data)

  options("r2d3.shadow" = FALSE) # set this option to avoid using shadow-root

  model_studio <- r2d3::r2d3(
    data = temp,
    script = system.file("d3js/modelStudio.js", package = "modelStudio"),
    dependencies = list(
      system.file("d3js/hackHead.js", package = "modelStudio"),
      system.file("d3js/myTools.js", package = "modelStudio"),
      system.file("d3js/d3-tip.js", package = "modelStudio"),
      system.file("d3js/d3-simple-slider.min.js", package = "modelStudio"),
      system.file("d3js/d3-interpolate-path.min.js", package = "modelStudio"),
      system.file("d3js/generatePlots.js", package = "modelStudio"),
      system.file("d3js/generateTooltipHtml.js", package = "modelStudio")
    ),
    css = system.file("d3js/modelStudio.css", package = "modelStudio"),
    options = new_options,
    d3_version = "4",
    sizing = object$sizingPolicy,
    elementId = widget_id,
    width = new_options$facet_dim[2]*(new_options$w + new_options$margin_left + new_options$margin_right),
    height = 100 + new_options$facet_dim[1]*(new_options$h + new_options$margin_top + new_options$margin_bottom)
  )

  model_studio$x$script <- remove_file_paths(model_studio$x$script, "js")
  model_studio$x$style <- remove_file_paths(model_studio$x$style, "css")

  class(model_studio) <- c(class(model_studio), "modelStudio")

  model_studio
}
Oops, something went wrong.