Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

permutations in feature_importance #31

Merged
merged 3 commits into from
Jul 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 70 additions & 47 deletions R/feature_importance.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#' Feature Importance Plots
#' Feature Importance
#'
#' This function calculates variable importance based on the drop in the Loss function after single-variable-perturbations.
#' For this reason it is also called the Variable Dropout Plot.
Expand All @@ -14,6 +14,8 @@
#' @param ... other parameters
#' @param type character, type of transformation that should be applied for dropout loss. 'raw' returns raw drop losses, 'ratio' returns \code{drop_loss/drop_loss_full_model} while 'difference' returns \code{drop_loss - drop_loss_full_model}
#' @param n_sample number of observations that should be sampled for calculation of variable importance. If NULL then variable importance will be calculated on the whole dataset (no sampling).
#' @param B integer, number of permutation rounds to perform on each variable
#' @param keep_raw_permutations logical or NULL, determines if output retains information for individual permutations; default is to omit for B=1 and keep otherwise
#' @param variables vector of variables. If NULL then variable importance will be tested for each variable from the `data` separately. By default NULL
#' @param variable_groups list of character vectors of variable names. This is for testing joint variable importance. If NULL then variable importance will be tested separately for `variables`. By default NULL. If specified then it will override `variables`.
#'
Expand Down Expand Up @@ -115,8 +117,10 @@ feature_importance <- function(x, ...)
feature_importance.explainer <- function(x,
loss_function = loss_root_mean_square,
...,
type = "raw",
type = c("raw", "ratio", "difference"),
n_sample = NULL,
B = 1,
keep_raw_permutations = NULL,
variables = NULL,
variable_groups = NULL,
label = NULL) {
Expand All @@ -140,6 +144,8 @@ feature_importance.explainer <- function(x,
label = label,
type = type,
n_sample = n_sample,
B = B,
keep_raw_permutations = keep_raw_permutations,
variables = variables,
variable_groups = variable_groups,
...
Expand All @@ -155,28 +161,26 @@ feature_importance.default <- function(x,
loss_function = loss_root_mean_square,
...,
label = class(x)[1],
type = "raw",
type = c("raw", "ratio", "difference"),
n_sample = NULL,
B = 1,
keep_raw_permutations = NULL,
variables = NULL,
variable_groups = NULL) {
if (!is.null(variable_groups)) {
if (!inherits(variable_groups, "list")) stop("variable_groups should be of class list")

wrong_names <- !all(sapply(variable_groups, function(variable_set) {
all(variable_set %in% names(data))
}))

all(variable_set %in% names(data))
}))
if (wrong_names) stop("You have passed wrong variables names in variable_groups argument")
if (!all(sapply(variable_groups, class) == "character")) stop("Elements of variable_groups argument should be of class character")
if (is.null(names(variable_groups))) warning("You have passed an unnamed list. The names of variable groupings will be created from variables names.")

if (is.null(names(variable_groups))) warning("You have passed an unnamed list. The names of variable groupings will be created from variables names.")
}

if (!(type %in% c("difference", "ratio", "raw")))
stop("Type shall be one of 'difference', 'ratio', 'raw'")



type <- match.arg(type)
B <- max(1, round(B))

# Adding variable set name when not specified
if (!is.null(variable_groups) && is.null(names(variable_groups))) {
names(variable_groups) <- sapply(variable_groups, function(variable_set) {
Expand All @@ -195,43 +199,62 @@ feature_importance.default <- function(x,
variables <- variable_groups
}

#variables <- colnames(data)
if (!is.null(n_sample)) {
sampled_rows <- sample.int(nrow(data), n_sample, replace = TRUE)
} else {
sampled_rows <- 1:nrow(data)
# one permutation round: subsample data, permute variables and compute losses
sampled_rows <- 1:nrow(data)
loss_after_permutation <- function() {
if (!is.null(n_sample)) {
sampled_rows <- sample.int(nrow(data), n_sample, replace = TRUE)
}
sampled_data <- data[sampled_rows, ]
observed <- y[sampled_rows]
# loss on the full model or when outcomes are permuted
loss_full <- loss_function(observed, predict_function(x, sampled_data))
loss_baseline <- loss_function(sample(observed), predict_function(x, sampled_data))
# loss upon dropping single variables (or single groups)
loss_features <- sapply(variables, function(variables_set) {
ndf <- sampled_data
ndf[, variables_set] <- ndf[sample(1:nrow(ndf)), variables_set]
predicted <- predict_function(x, ndf)
loss_function(observed, predicted)
})
c("_full_model_" = loss_full, loss_features, "_baseline_" = loss_baseline)
}
sampled_data <- data[sampled_rows, ]
observed <- y[sampled_rows]

loss_0 <- loss_function(observed,
predict_function(x, sampled_data))
loss_full <- loss_function(sample(observed),
predict_function(x, sampled_data))

res <- sapply(variables, function(variables_set) {
ndf <- sampled_data
# sample variables in variables_set
ndf[, variables_set] <- ndf[sample(1:nrow(ndf)), variables_set]

predicted <- predict_function(x, ndf)
loss_function(observed, predicted)
})

res <- sort(res)
res <-
data.frame(
variable = c("_full_model_", names(res), "_baseline_"),
dropout_loss = c(loss_0, res, loss_full)
)
# permute B times, collect results into single matrix
raw <- replicate(B, loss_after_permutation())

# main result df with dropout_loss averages, with _full_model_ first and _baseline_ last
res <- apply(raw, 1, mean)
res_baseline <- res["_baseline_"]
res_full <- res["_full_model_"]
res <- sort(res[!names(res) %in% c("_full_model_", "_baseline_")])
res <- data.frame(
variable = c("_full_model_", names(res), "_baseline_"),
dropout_loss = c(res_full, res, res_baseline),
label = label,
row.names = NULL
)
if (type == "ratio") {
res$dropout_loss = res$dropout_loss / loss_0
res$dropout_loss = res$dropout_loss / res_full
}
if (type == "difference") {
res$dropout_loss = res$dropout_loss - loss_0
res$dropout_loss = res$dropout_loss - res_full
}

class(res) <- c("feature_importance_explainer", "data.frame")
res$label <- label

# record details of permutations
attr(res, "B") <- B
if (is.null(keep_raw_permutations)) {
keep_raw_permutations <- (B > 1)
}
if (keep_raw_permutations) {
attr(res, "raw_permutations") <- data.frame(
variable = rep(rownames(raw), ncol(raw)),
permutation = rep(seq_len(B), each = nrow(raw)),
dropout_loss = as.vector(raw),
label = label
)
}

res
}

16 changes: 11 additions & 5 deletions man/feature_importance.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading