Skip to content

Commit

Permalink
Splits feature_combinations into multiple functions (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
nikolase90 committed Aug 19, 2019
1 parent ebfae35 commit f38bf2e
Show file tree
Hide file tree
Showing 10 changed files with 334 additions and 119 deletions.
4 changes: 3 additions & 1 deletion NAMESPACE
Expand Up @@ -6,7 +6,9 @@ export(cluster_features)
export(compute_kshap)
export(correction_matrix_cpp)
export(feature_combinations)
export(feature_exact)
export(feature_matrix_cpp)
export(feature_not_exact)
export(gaussian_transform)
export(gaussian_transform_separate)
export(global_arguments)
Expand All @@ -15,14 +17,14 @@ export(inv_gaussian_transform)
export(mahalanobis_distance_cpp)
export(observation_impute)
export(observation_impute_cpp)
export(observation_weights)
export(plot_kshap)
export(prediction_vector)
export(predictions)
export(prepare_kshap)
export(rss_cpp)
export(sample_combinations)
export(sample_copula)
export(sample_features_cpp)
export(sample_gaussian)
export(shapley_weights)
export(weight_matrix)
Expand Down
8 changes: 8 additions & 0 deletions R/RcppExports.R
Expand Up @@ -90,6 +90,14 @@ mahalanobis_distance_cpp <- function(featureList, Xtrain_mat, Xtest_mat, mcov, S
.Call(`_shapr_mahalanobis_distance_cpp`, featureList, Xtrain_mat, Xtest_mat, mcov, S_scale_dist)
}

#' @keywords internal
#'
#' @export
#'
sample_features_cpp <- function(m, nfeatures) {
.Call(`_shapr_sample_features_cpp`, m, nfeatures)
}

#' Get imputed data
#'
#' @param index_xtrain Positive integer. Represents a sequence of row indices from \code{xtrain},
Expand Down
181 changes: 104 additions & 77 deletions R/features.R
Expand Up @@ -6,8 +6,8 @@
#' The returned data.table contains the following columns
#' \describe{
#' \item{ID}{Positive integer. Unique key for combination}
#' \item{features}{List}
#' \item{nfeautres}{Positive integer}
#' \item{features}{List of integer vectors}
#' \item{nfeatures}{Positive integer}
#' \item{N}{Positive integer}
#' }
#'
Expand All @@ -16,88 +16,115 @@
#' @export
#'
#' @author Nikolai Sellereite, Martin Jullum
feature_combinations <- function(m, exact = TRUE, noSamp = 200, shapley_weight_inf_replacement = 10^6, reduce_dim = T) {
if (!exact && noSamp > (2^m - 2) && !replace) {
feature_combinations <- function(m, exact = TRUE, noSamp = 200, shapley_weight_inf_replacement = 10^6, reduce_dim = TRUE) {

# Not supported for m > 30
if (m > 30) {
stop("Currently we are not supporting cases where m > 30.")
}

if (!exact && noSamp > (2^m - 2) && !reduce_dim) {
noSamp <- 2^m - 2
cat(paste0("noSamp is larger than 2^m = ", 2^m, ". Using exact instead."))
cat(sprintf("noSamp is larger than 2^m = %d. Using exact instead.", 2^m))
}
if (exact == TRUE) {
N <- 2^m
X <- data.table(ID = 1:N)
combinations <- lapply(0:m, utils::combn, x = m, simplify = FALSE)
X[, features := unlist(combinations, recursive = FALSE)]
X[, nfeatures := length(features[[1]]), ID]
X[, N := .N, nfeatures]
X[!(nfeatures %in% c(0, m)), shapley_weight := shapley_weights(m = m, N = N, s = nfeatures)]
X[nfeatures %in% c(0, m), shapley_weight := shapley_weight_inf_replacement]
X[, no := 1]

if (exact) {
dt <- feature_exact(m, shapley_weight_inf_replacement)
} else {
## Find weights for given number of features ----------
DT0 <- data.table(nfeatures = head(1:m, -1))
DT0[, N := unlist(lapply(nfeatures, choose, n = m))]
DT0[, shapley_weight := shapley_weights(m = m, N = N, s = nfeatures)]
DT0[, samp_weight := shapley_weight * N]
DT0[, samp_weight := samp_weight / sum(samp_weight)]

## Sample number of features ----------
X <- data.table(
nfeatures = sample(
x = DT0[["nfeatures"]],
dt <- feature_not_exact(m, noSamp, shapley_weight_inf_replacement, reduce_dim)
}

return(dt)
}

#' @keywords internal
#' @export
feature_exact <- function(m, shapley_weight_inf_replacement = 10^6) {

dt <- data.table::data.table(ID = seq(2^m))
combinations <- lapply(0:m, utils::combn, x = m, simplify = FALSE)
dt[, features := unlist(combinations, recursive = FALSE)]
dt[, nfeatures := length(features[[1]]), ID]
dt[, N := .N, nfeatures]
dt[, shapley_weight := shapley_weights(m = m, N = N, s = nfeatures, shapley_weight_inf_replacement)]
dt[, no := 1]

return(dt)
}

#' @keywords internal
#' @export
feature_not_exact <- function(m, noSamp = 200, shapley_weight_inf_replacement = 10^6, reduce_dim = TRUE) {

# Find weights for given number of features ----------
nfeatures <- seq(m - 1)
n <- sapply(nfeatures, choose, n = m)
w <- shapley_weights(m = m, N = n, s = nfeatures) * n
p <- w / sum(w)

# Sample number of chosen features ----------
X <- data.table::data.table(
nfeatures = c(
0,
sample(
x = nfeatures,
size = noSamp,
replace = TRUE,
prob = DT0[["samp_weight"]]
)
prob = p
),
m
)
)
X[, nfeatures := as.integer(nfeatures)]

## Sample specific set of features # Not optimal, as it is a bit slow for noSamp -------
setkey(X, nfeatures)
Samp <- sapply(X = X$nfeatures, FUN = function(x) {
aa <- rep(NA, m)
aa[1:x] <- sample(x = 1:m, size = x)
aa
})
Samp <- t(apply(X = Samp, MARGIN = 2, FUN = sort, na.last = T))
Samp.list <- apply(X = Samp, MARGIN = 1, FUN = function(x) {
x[!is.na(x)]
})

X <- cbind(X, Samp)
X[, no := .N, by = mget(paste0("V", 1:m))] # Counting repetitions of the same sample

if (reduce_dim) {
isDup <- duplicated(X)
X[, features := Samp.list]
X <- X[!isDup, ]
} else {
X[, no := 1]
X[, features := Samp.list]
}

X[, paste0("V", 1:m) := NULL]
X[, ID := .I]

nms <- c("ID", "nfeatures", "features", "no")
setcolorder(X, nms)

## Add zero features and m features ----------
X_zero_all <- data.table(
ID = seq(X[, max(ID)] + 1, length.out = 2),
num_var = c(0, m),
comb = c(list(numeric(0)), list(1:m)),
no = 1
)
X <- rbindlist(list(X, X_zero_all))
setkey(X, nfeatures)

## Add number of combinations
X <- merge(x = X, y = DT0[, .(nfeatures, N, shapley_weight)], all.x = TRUE, on = "nfeatures")
nms <- c("ID", "features", "nfeatures", "N", "shapley_weight", "no")
setcolorder(X, nms)
X[, ID := .I]
X[nfeatures %in% c(0, m), `:=`(
shapley_weight = shapley_weight_inf_replacement,
N = 1
)]
# Sample specific set of features -------
data.table::setkey(X, nfeatures)
feature_sample <- sample_features_cpp(m, X[["nfeatures"]])

# Get number of occurences and duplicated rows-------
r <- helper_feature(m, feature_sample)
X[, no := r[["no"]]]
X[, is_duplicate := r[["is_duplicate"]]]
X[, ID := .I]

# Populate table and remove duplicated rows -------
X[, features := feature_sample]
if (reduce_dim && any(X[["is_duplicate"]])) {
X <- X[is_duplicate == FALSE]
X[, no := 1]
}
X[, is_duplicate := NULL]
nms <- c("ID", "nfeatures", "features", "no")
data.table::setcolorder(X, nms)

# Add shapley weight and number of combinations
X[, shapley_weight := shapley_weight_inf_replacement]
X[, N := 1]
X[between(nfeatures, 1, m - 1), ind := TRUE]
X[ind == TRUE, shapley_weight := p[nfeatures]]
X[ind == TRUE, N := n[nfeatures]]
X[, ind := NULL]

# Set column order and key table
nms <- c("ID", "features", "nfeatures", "N", "shapley_weight", "no")
data.table::setcolorder(X, nms)
data.table::setkey(X, nfeatures)
X[, ID := .I]
X[, N := as.integer(N)]

return(X)
}

#' @keywords internal
helper_feature <- function(m, feature_sample) {

x <- feature_matrix_cpp(feature_sample, m)
dt <- data.table::data.table(x)
cnms <- paste0("V", seq(m))
data.table::setnames(dt, cnms)
dt[, no := as.integer(.N), by = cnms]
dt[, is_duplicate := duplicated(dt)]
dt[, (cnms) := NULL]

return(dt)
}
21 changes: 4 additions & 17 deletions R/shapley.R
Expand Up @@ -7,24 +7,11 @@
#' @export
#'
#' @author Nikolai Sellereite
shapley_weights <- function(m, N, s) {
(m - 1) / (N * s * (m - s))
}

#' Calculate Shapley weights
#'
#' @param X data.table
#'
#' @return data.table
#'
#' @export
#'
#' @author Nikolai Sellereite
observation_weights <- function(X, m) {
X[-c(1, .N), weight := shapley_weights(m = m, N = N, s = nfeatures), ID]
X[c(1, .N), weight := 10^6]
shapley_weights <- function(m, N, s, weight_zero_m = 10^6) {

return(X)
x <- (m - 1) / (N * s * (m - s))
x[!is.finite(x)] <- weight_zero_m
x
}

#' Get weighted matrix
Expand Down
6 changes: 3 additions & 3 deletions man/feature_combinations.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 0 additions & 20 deletions man/observation_weights.Rd

This file was deleted.

2 changes: 1 addition & 1 deletion man/shapley_weights.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions src/RcppExports.cpp
Expand Up @@ -90,6 +90,18 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// sample_features_cpp
List sample_features_cpp(int m, IntegerVector nfeatures);
RcppExport SEXP _shapr_sample_features_cpp(SEXP mSEXP, SEXP nfeaturesSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< int >::type m(mSEXP);
Rcpp::traits::input_parameter< IntegerVector >::type nfeatures(nfeaturesSEXP);
rcpp_result_gen = Rcpp::wrap(sample_features_cpp(m, nfeatures));
return rcpp_result_gen;
END_RCPP
}
// observation_impute_cpp
NumericMatrix observation_impute_cpp(IntegerVector index_xtrain, IntegerVector index_s, NumericMatrix xtrain, NumericMatrix xtest, IntegerMatrix S);
RcppExport SEXP _shapr_observation_impute_cpp(SEXP index_xtrainSEXP, SEXP index_sSEXP, SEXP xtrainSEXP, SEXP xtestSEXP, SEXP SSEXP) {
Expand Down Expand Up @@ -139,6 +151,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_shapr_aicc_full_single_cpp", (DL_FUNC) &_shapr_aicc_full_single_cpp, 5},
{"_shapr_aicc_full_cpp", (DL_FUNC) &_shapr_aicc_full_cpp, 6},
{"_shapr_mahalanobis_distance_cpp", (DL_FUNC) &_shapr_mahalanobis_distance_cpp, 5},
{"_shapr_sample_features_cpp", (DL_FUNC) &_shapr_sample_features_cpp, 2},
{"_shapr_observation_impute_cpp", (DL_FUNC) &_shapr_observation_impute_cpp, 5},
{"_shapr_weight_matrix_cpp", (DL_FUNC) &_shapr_weight_matrix_cpp, 4},
{"_shapr_feature_matrix_cpp", (DL_FUNC) &_shapr_feature_matrix_cpp, 2},
Expand Down
24 changes: 24 additions & 0 deletions src/features.cpp
@@ -0,0 +1,24 @@
#include <Rcpp.h>
using namespace Rcpp;

//' @keywords internal
//'
//' @export
//'
// [[Rcpp::export]]
List sample_features_cpp(int m, IntegerVector nfeatures) {

int n = nfeatures.length();
List l(n);

for (int i = 0; i < n; i++) {

int s = nfeatures[i];
IntegerVector k = sample(m, s);
std::sort(k.begin(), k.end());
l[i] = k;

}

return l;
}

0 comments on commit f38bf2e

Please sign in to comment.