Skip to content

Commit

Permalink
add ms_update_observations
Browse files Browse the repository at this point in the history
  • Loading branch information
hbaniecki committed May 6, 2020
1 parent 5d30a1f commit 35d1beb
Show file tree
Hide file tree
Showing 13 changed files with 456 additions and 40 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: modelStudio
Title: Interactive Studio for Explanatory Model Analysis
Version: 1.0.2.9000
Version: 1.1.0
Authors@R:
c(person("Hubert", "Baniecki", role = c("aut", "cre"),
email = "hbaniecki@gmail.com",
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ S3method(modelStudio,python.builtin.object)
export(modelStudio)
export(modelStudioOptions)
export(ms_options)
export(ms_update_observations)
export(ms_update_options)
import(progress)
importFrom(grDevices,nclass.Sturges)
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# modelStudio (development)
* ...

# modelStudio 1.1.0
* rename `modelStudioOptions()` to `ms_options()`
* add new `ms_update_options()` function that updates the options of a `modelStudio` object
* add new `ms_update_observations()` function that updates the observations of a `modelStudio` object
* lower `B` default value from `15` to `10` and `N` default value from `400` to `300`
* `feature_importance` is now calculated on `10*N` sampled rows from the data
* use `ranger` instead of `randomForest` everywhere
* remove unnecessary imports, update the documentation
* add `auto_unbox = TRUE` to `jsonlite::toJSON` calls and change the `.js` code to comply
* add new class `"modelStudio"` to the `modelStudio()` output

# modelStudio 1.0.2
Expand Down
14 changes: 7 additions & 7 deletions R/modelStudio.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#' \href{https://pbiecek.github.io/ema/}{Explanatory Model Analysis: Explore, Explain and Examine Predictive Models}
#'
#' @param explainer An \code{explainer} created with \code{DALEX::explain()}.
#' @param new_observation A new observation with columns that correspond to variables used in the model.
#' @param new_observation New observations with columns that correspond to variables used in the model.
#' @param new_observation_y True label for \code{new_observation} (optional).
#' @param facet_dim Dimensions of the grid. Default is \code{c(2,2)}.
#' @param time Time in ms. Set the animation length. Default is \code{500}.
Expand Down Expand Up @@ -286,8 +286,8 @@ modelStudio.explainer <- function(explainer,
paste0("iBreakDown::local_attributions (", i, ")"), show_info, pb, 2)
sv <- calculate(
iBreakDown::shap(
model, data, predict_function, new_observation, label = label, B = 3*B),
paste0("iBreakDown::shap (", i, ")"), show_info, pb, 2*B)
model, data, predict_function, new_observation, label = label, B = B),
paste0("iBreakDown::shap (", i, ")"), show_info, pb, 3*B)
cp <- calculate(
ingredients::ceteris_paribus(
model, data, predict_function, new_observation, label = label),
Expand Down Expand Up @@ -324,8 +324,8 @@ modelStudio.explainer <- function(explainer,
paste0("iBreakDown::local_attributions (", i, ")"), show_info, pb, 2)
sv <- calculate(
iBreakDown::shap(
model, data, predict_function, new_observation, label = label, B = 3*B),
paste0("iBreakDown::shap (", i, ")"), show_info, pb, 2*B)
model, data, predict_function, new_observation, label = label, B = B),
paste0("iBreakDown::shap (", i, ")"), show_info, pb, 3*B)
cp <- calculate(
ingredients::ceteris_paribus(
model, data, predict_function, new_observation, label = label),
Expand All @@ -341,7 +341,7 @@ modelStudio.explainer <- function(explainer,

# pack explanation data to json and make hash for htmlwidget
names(obs_list) <- rownames(obs_data)
temp <- jsonlite::toJSON(list(obs_list, fi_data, pd_data, ad_data, fd_data, at_data))
temp <- jsonlite::toJSON(list(obs_list, fi_data, pd_data, ad_data, fd_data, at_data), auto_unbox = TRUE)
widget_id <- paste0("widget-", digest::digest(temp))

# prepare observation data for drop down
Expand Down Expand Up @@ -462,7 +462,7 @@ calculate <- function(expr, function_name, show_info = FALSE, pb = NULL, ticks =
expr
},
error = function(e) {
warning(paste0("Error occurred in ", function_name, " function: ", e$message))
warning(paste0("\nError occurred in ", function_name, " function: ", e$message))
NULL
})
}
Expand Down
2 changes: 1 addition & 1 deletion R/ms_options.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ ms_options <- function(...) {
at_point_color = "#371ea3"
)

# input user options
# input new options
default_options[names(list(...))] <- list(...)

default_options
Expand Down
1 change: 0 additions & 1 deletion R/ms_update_new_observation.R

This file was deleted.

273 changes: 273 additions & 0 deletions R/ms_update_observations.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
#' @title Update the observations of a modelStudio object
#'
#' @description
#' This function calculates local explanations on new observations and adds them
#' to a \code{modelStudio} object.
#'
#' @param object A \code{modelStudio} created with \code{modelStudio()}.
#' @param explainer An \code{explainer} created with \code{DALEX::explain()}.
#' @param new_observation New observations with columns that correspond to variables used in the model.
#'  If \code{NULL} (default), 3 observations are sampled at random from the explainer's data.
#' @param new_observation_y True label for \code{new_observation} (optional).
#' @param max_features Maximum number of features to be included in BD and SV plots.
#'  Default is \code{10}.
#' @param B Number of permutation rounds used for calculation of SV.
#'  Default is \code{10}.
#'  See \href{https://modelstudio.drwhy.ai/articles/ms-perks-features.html#more-calculations-means-more-time}{\bold{vignette}}.
#' @param show_info Print progress information to the console. Default is \code{TRUE}.
#' @param parallel Speed up the computation using \code{parallelMap::parallelMap()}.
#'  See \href{https://modeloriented.github.io/modelStudio/articles/ms-perks-features.html#parallel-computation}{\bold{vignette}}.
#'  This might interfere with showing progress using \code{show_info}.
#' @param overwrite Overwrite existing observations and their explanations.
#'  Default is \code{FALSE} which means add new observations to the existing ones.
#'  When ids of new observations overlap with existing ones, the existing entries are kept
#'  and a warning is raised.
#' @param ... Other parameters, passed on to the internal functions that prepare
#'  the Break Down and Shapley Values data.
#'
#' @return An object of the \code{r2d3, htmlwidget, modelStudio} class.
#'
#' @references
#'
#' \itemize{
#'   \item The input object is implemented in \href{https://modeloriented.github.io/DALEX/}{\bold{DALEX}}
#'   \item Feature Importance, Ceteris Paribus, Partial Dependence and Accumulated Dependence plots
#'    are implemented in \href{https://modeloriented.github.io/ingredients/}{\bold{ingredients}}
#'   \item Break Down and Shapley Values plots are implemented in \href{https://modeloriented.github.io/iBreakDown/}{\bold{iBreakDown}}
#' }
#'
#' @seealso
#' Vignettes: \href{https://modeloriented.github.io/modelStudio/articles/ms-r-python-examples.html}{\bold{modelStudio - R & Python examples}}
#' and \href{https://modeloriented.github.io/modelStudio/articles/ms-perks-features.html}{\bold{modelStudio - perks and features}}
#'
#' @examples
#' library("DALEX")
#' library("modelStudio")
#'
#' # fit a model
#' model_apartments <- glm(m2.price ~. , data = apartments)
#'
#' # create an explainer for the model
#' explainer_apartments <- explain(model_apartments,
#'                                 data = apartments,
#'                                 y = apartments$m2.price)
#'
#' # make a studio for the model
#' ms <- modelStudio(explainer_apartments)
#' ms
#'
#' # add new observations
#' ms <- ms_update_observations(ms,
#'                              explainer_apartments,
#'                              new_observation = apartments[100:101,],
#'                              new_observation_y = apartments$m2.price[100:101])
#' ms
#'
#' # overwrite the observations with new ones
#' ms <- ms_update_observations(ms,
#'                              explainer_apartments,
#'                              new_observation = apartments[100:101,],
#'                              overwrite = TRUE)
#' ms
#'
#' @export
#' @rdname ms_update_observations
ms_update_observations <- function(object,
                                   explainer,
                                   new_observation = NULL,
                                   new_observation_y = NULL,
                                   max_features = 10,
                                   B = 10,
                                   show_info = TRUE,
                                   parallel = FALSE,
                                   overwrite = FALSE,
                                   ...) {

  stopifnot(inherits(object, "modelStudio"))
  stopifnot(inherits(explainer, "explainer"))

  options <- object$x$options
  model <- explainer$model
  data <- explainer$data
  y <- explainer$y
  predict_function <- explainer$predict_function
  label <- explainer$label

  if (is.null(rownames(data))) {
    rownames(data) <- seq_len(nrow(data))
  }

  if (is.null(new_observation)) {
    if (show_info) message("`new_observation` argument is NULL.\n",
                           "Observations needed to calculate local explanations are taken at random from the data.\n")
    new_observation <- ingredients::select_sample(data, 3)

  } else if (is.null(dim(new_observation))) {
    warning("`new_observation` argument is not a data.frame nor a matrix, coerced to data.frame\n")
    new_observation <- as.data.frame(new_observation)

  } else if (is.null(rownames(new_observation))) {
    rownames(new_observation) <- seq_len(nrow(new_observation))
  }

  ## fail early if the model cannot predict on a single new observation
  check_single_prediction <- try(predict_function(model, new_observation[1,, drop = FALSE]), silent = TRUE)
  if (inherits(check_single_prediction, "try-error")) {
    stop("`predict_function` returns an error when executed on `new_observation[1,, drop = FALSE]` \n")
  }

  ## get proper names of features that are not the target
  is_y <- is_y_in_data(data, y)
  potential_variable_names <- names(is_y[!is_y])
  variable_names <- intersect(potential_variable_names, colnames(new_observation))
  ## get rid of target in data; drop = FALSE keeps a one-feature dataset two-dimensional
  data <- data[, !is_y, drop = FALSE]

  obs_count <- dim(new_observation)[1]
  obs_data <- new_observation

  ## later update progress bar after all explanation functions
  ## `pb` stays NULL when no progress is shown, so that it is always defined
  pb <- NULL
  if (show_info) {
    pb <- progress_bar$new(
      format = " Calculating :what \n Elapsed time: :elapsedfull ETA::eta", # :percent [:bar]
      total = (3*B + 2 + 1)*obs_count,
      show_after = 0
    )
    pb$tick(0, tokens = list(what = "..."))
  }

  ## calculate and prepare BD, SV and CP explanations for the i-th new observation;
  ## shared by the parallel and the sequential path
  explain_observation <- function(i, model, data, predict_function, label, B, show_boxplot, ...) {
    new_observation <- obs_data[i,, drop = FALSE]

    bd <- calculate(
      iBreakDown::local_attributions(
        model, data, predict_function, new_observation, label = label),
      paste0("iBreakDown::local_attributions (", i, ")"), show_info, pb, 2)
    sv <- calculate(
      iBreakDown::shap(
        model, data, predict_function, new_observation, label = label, B = B),
      paste0("iBreakDown::shap (", i, ")"), show_info, pb, 3*B)
    cp <- calculate(
      ingredients::ceteris_paribus(
        model, data, predict_function, new_observation, label = label),
      paste0("ingredients::ceteris_paribus (", i, ")"), show_info, pb, 1)

    bd_data <- prepare_break_down(bd, max_features, ...)
    sv_data <- prepare_shapley_values(sv, max_features, show_boxplot, ...)
    cp_data <- prepare_ceteris_paribus(cp, variables = variable_names)

    list(bd_data, cp_data, sv_data)
  }

  if (parallel) {
    parallelMap::parallelStart()
    ## make sure the parallel backend is stopped even if a calculation fails
    on.exit(parallelMap::parallelStop(), add = TRUE)
    parallelMap::parallelLibrary(packages = loadedNamespaces())

    obs_list <- parallelMap::parallelMap(explain_observation, seq_len(obs_count),
                                         more.args = list(
                                           model = model,
                                           data = data,
                                           predict_function = predict_function,
                                           label = label,
                                           B = B,
                                           show_boxplot = options$show_boxplot,
                                           ...
                                         ))

  } else {
    ## count once per observation
    obs_list <- vector("list", obs_count)
    for (i in seq_len(obs_count)) {
      obs_list[[i]] <- explain_observation(i,
                                           model = model,
                                           data = data,
                                           predict_function = predict_function,
                                           label = label,
                                           B = B,
                                           show_boxplot = options$show_boxplot,
                                           ...)
    }
  }

  names(obs_list) <- rownames(obs_data)

  #:# prepare new observation data for drop down
  if (is.null(new_observation_y)) {
    ## no labels given: the drop down shows the bare observation ids
    between <- ""
    new_observation_y <- ""
  } else {
    between <- " - "
  }
  ## stringsAsFactors = FALSE keeps the columns as character on R < 4.0 too,
  ## so that they bind cleanly with the old drop down data below
  drop_down_data <- as.data.frame(cbind(rownames(obs_data),
                                        paste0(rownames(obs_data), between, new_observation_y)),
                                  stringsAsFactors = FALSE)
  colnames(drop_down_data) <- c("id", "text")

  #:# extract old data
  old_data <- jsonlite::fromJSON(object$x$data, simplifyVector = FALSE)

  if (!overwrite) {
    #:# extract old drop down and merge with new one
    old_drop_down_data <- jsonlite::fromJSON(options$drop_down_data,
                                             simplifyVector = FALSE, simplifyDataFrame = TRUE)
    drop_down_data <- rbind(old_drop_down_data, drop_down_data)

    #:# update new data
    obs_list <- c(old_data[[1]], obs_list)
    if (length(unique(names(obs_list))) != length(obs_list)) {
      warning("new_observation ids overlap with existing data, using unique ids")
      ## both selections keep the first occurrence of an id, i.e. the existing entry
      obs_list <- obs_list[unique(names(obs_list))]
      drop_down_data <- drop_down_data[!duplicated(drop_down_data$id),]
    }
  }

  #:# input new data; global explanations (elements 2-6) are reused unchanged
  temp <- jsonlite::toJSON(list(obs_list, old_data[[2]], old_data[[3]],
                                old_data[[4]], old_data[[5]], old_data[[6]]),
                           auto_unbox = TRUE)
  widget_id <- paste0("widget-", digest::digest(temp))

  #:# extract old options and update them
  new_options <- options
  new_options$widget_id <- widget_id
  new_options$variable_names <- variable_names
  new_options$footer_text <- paste0("Site built with modelStudio v", packageVersion("modelStudio"),
                                    " on ", format(Sys.time(), usetz = FALSE))
  new_options$drop_down_data <- jsonlite::toJSON(drop_down_data)

  ## set this option to avoid using shadow-root
  ## NOTE(review): this changes a session-wide option and does not restore it
  options("r2d3.shadow" = FALSE)

  model_studio <- r2d3::r2d3(
    data = temp,
    script = system.file("d3js/modelStudio.js", package = "modelStudio"),
    dependencies = list(
      system.file("d3js/hackHead.js", package = "modelStudio"),
      system.file("d3js/myTools.js", package = "modelStudio"),
      system.file("d3js/d3-tip.js", package = "modelStudio"),
      system.file("d3js/d3-simple-slider.min.js", package = "modelStudio"),
      system.file("d3js/d3-interpolate-path.min.js", package = "modelStudio"),
      system.file("d3js/generatePlots.js", package = "modelStudio"),
      system.file("d3js/generateTooltipHtml.js", package = "modelStudio")
    ),
    css = system.file("d3js/modelStudio.css", package = "modelStudio"),
    options = new_options,
    d3_version = "4",
    sizing = object$sizingPolicy,
    elementId = widget_id,
    width = new_options$facet_dim[2]*(new_options$w + new_options$margin_left + new_options$margin_right),
    height = 100 + new_options$facet_dim[1]*(new_options$h + new_options$margin_top + new_options$margin_bottom)
  )

  model_studio$x$script <- remove_file_paths(model_studio$x$script, "js")
  model_studio$x$style <- remove_file_paths(model_studio$x$style, "css")

  class(model_studio) <- c(class(model_studio), "modelStudio")

  model_studio
}
Loading

0 comments on commit 35d1beb

Please sign in to comment.