Skip to content

Commit

Permalink
feature #94
Browse files Browse the repository at this point in the history
  • Loading branch information
hbaniecki committed Nov 16, 2020
1 parent be25d72 commit fb0797d
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 9 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: modelStudio
Title: Interactive Studio for Explanatory Model Analysis
Version: 2.0.0.9000
Version: 2.1.0
Authors@R:
c(person("Hubert", "Baniecki", role = c("aut", "cre"),
email = "hbaniecki@gmail.com",
Expand Down
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# modelStudio (development)
# modelStudio 2.1.0
* **DEFAULTS CHANGES**: if `new_observation = NULL` then choose `new_observation_n = 3` observations, evenly spread by the histogram bins of `y_hat`. This shall always include the observations, which ids are `which.min(y_hat)` and `which.max(y_hat)`. Additionally, improve the observation dropdown text in dashboard. [(#94)](https://github.com/ModelOriented/modelStudio/issues/94)
* This version requires `DALEX v2.0.1`
* added new options to `ms_options`: `ms_subtitle`, `ms_margin_top` and `ms_margin_bottom`
* added new parameters to `modelStudio()`: `N_fi = 10*N` and `B_fi = B`
Expand Down
51 changes: 44 additions & 7 deletions R/modelStudio.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#' @param explainer An \code{explainer} created with \code{DALEX::explain()}.
#' @param new_observation New observations with columns that correspond to variables used in the model.
#' @param new_observation_y True label for \code{new_observation} (optional).
#' @param new_observation_n Number of observations to be taken from the \code{explainer$data} if \code{new_observation = NULL}.
#' See \href{https://modelstudio.drwhy.ai/articles/ms-perks-features.html#instance-explanations}{\bold{vignette}}
#' @param facet_dim Dimensions of the grid. Default is \code{c(2,2)}.
#' @param time Time in ms. Set the animation length. Default is \code{500}.
#' @param max_features Maximum number of features to be included in BD and SV plots.
Expand Down Expand Up @@ -169,6 +171,7 @@ modelStudio <- function(explainer, ...) {
modelStudio.explainer <- function(explainer,
new_observation = NULL,
new_observation_y = NULL,
new_observation_n = 3,
facet_dim = c(2,2),
time = 500,
max_features = 10,
Expand Down Expand Up @@ -204,8 +207,10 @@ modelStudio.explainer <- function(explainer,

if (is.null(new_observation)) {
if (show_info) message("`new_observation` argument is NULL.\n",
"Observations needed to calculate local explanations are taken at random from the data.\n")
new_observation <- ingredients::select_sample(data, 3)
"`new_observation_n` observations needed to calculate local explanations are taken from the data.\n")
ret <- sample_new_observation(explainer, new_observation_n)
new_observation <- ret[['no']]
new_observation_y <- ret[['no_y']]

} else if (is.null(dim(new_observation))) {
warning("`new_observation` argument is not a data.frame nor a matrix, coerced to data.frame\n")
Expand Down Expand Up @@ -421,11 +426,13 @@ modelStudio.explainer <- function(explainer,
paste0("widget-", digest::digest(temp)))

# prepare observation data for drop down
between <- " - "
if (is.null(new_observation_y)) new_observation_y <- between <- ""
drop_down_data <- as.data.frame(cbind(rownames(obs_data),
paste0(rownames(obs_data), between, new_observation_y)),
stringsAsFactors=TRUE)
str_between <- " | y: "
str_before <- "id: "
if (is.null(new_observation_y)) new_observation_y <- str_between <- str_before <- ""
drop_down_data <- as.data.frame(
cbind(rownames(obs_data),
paste0(str_before, rownames(obs_data), str_between, new_observation_y)),
stringsAsFactors=TRUE)
colnames(drop_down_data) <- c("id", "text")

# prepare footer text and ms title
Expand Down Expand Up @@ -600,6 +607,7 @@ is_binary <- function(y) {
is.numeric(y) & length(unique(y)) == 2
}

# safety check for explainer
check_explainer <- function(explainer) {

if (is.null(explainer$data))
Expand Down Expand Up @@ -628,4 +636,33 @@ check_explainer <- function(explainer) {
explainer
}

# choose observations
sample_new_observation <- function(explainer, new_observation_n = 3) {
if (is.null(explainer$y_hat)) {
y_hat <- try(predict(explainer), silent = TRUE)
if (class(y_hat)[1] == "try-error")
stop('`predict(explainer)` returns an error')
} else {
y_hat <- explainer$y_hat
}

n <- dim(explainer$data)[1]

if (new_observation_n >= n) {
new_observation_n <- n
}

if (new_observation_n == 1) {
ids <- which.min(y_hat)
} else if (new_observation_n == 2) {
ids <- c(which.min(y_hat), which.max(y_hat))
} else if (new_observation_n == 3) {
ids <- c(which.min(y_hat), as.integer(n/2), which.max(y_hat))
} else {
y_hat_coded <- cut(y_hat, new_observation_n - 2, labels=FALSE)
ids <- sapply(1:(new_observation_n - 2), FUN = function (x) match(x, y_hat_coded))
ids <- c(which.min(y_hat), ids, which.max(y_hat))
}

list(no = explainer$data[ids,], no_y = explainer$y[ids])
}
4 changes: 4 additions & 0 deletions man/modelStudio.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions vignettes/ms-perks-features.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ modelStudio(explainer,
new_observation_y = true_labels)
```

If `new_observation = NULL`, then choose `new_observation_n` observations, evenly spread by the histogram bins of `y_hat`. This shall always include the observations, which ids are `which.min(y_hat)` and `which.max(y_hat)`.

```{r eval = FALSE}
modelStudio(explainer, new_observation_n = 5) # default is 3
```

### grid size

Achieve bigger or smaller `modelStudio` grid with `facet_dim` parameter.
Expand Down

0 comments on commit fb0797d

Please sign in to comment.