/
calculate_weight.R
66 lines (65 loc) · 3.4 KB
/
calculate_weight.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#' Calculate empirical density and weights based on a variable split.
#'
#' This function calculates the empirical density of the raw data over the
#' intervals defined by the variable split of Ceteris Paribus profiles. It then
#' assigns a weight to every value of the profiles generated by
#' \code{DALEX::predict_profile()}, \code{DALEX::individual_profile()} or
#' \code{ingredients::ceteris_paribus()}. The weight is the density of the
#' interval the value falls into.
#'
#' For every column of \code{data}, the break points are the column minimum,
#' the split points from \code{variable_split} and the column maximum (missing
#' values are ignored when taking the minimum and maximum). The density of an
#' interval is the number of rows of \code{data} whose value lies in it,
#' divided by \code{nrow(data)}. Profile values lying outside the outermost
#' break points get a weight of \code{NA}.
#'
#' @param profiles \code{data.frame} generated by \code{DALEX::predict_profile()}, \code{DALEX::individual_profile()} or \code{ingredients::ceteris_paribus()}
#' @param data \code{data.frame} with raw data to model. Its columns must be
#'   numeric, because they are discretised with \code{cut()}.
#' @param variable_split list generated by \code{vivo::calculate_variable_split()}
#'
#' @return A named list with one numeric vector per column of \code{data}.
#'   Each vector holds the weights of the profile rows whose \code{_vname_}
#'   equals that column name, in the same order as those rows appear in
#'   \code{profiles}.
#'
#' @examples
#'
#' library("DALEX", warn.conflicts = FALSE, quietly = TRUE)
#' data(apartments)
#'
#' split <- vivo::calculate_variable_split(apartments,
#'   variables = colnames(apartments),
#'   grid_points = 101)
#'
#' library("randomForest", warn.conflicts = FALSE, quietly = TRUE)
#' apartments_rf_model <- randomForest(m2.price ~ construction.year + surface +
#'   floor + no.rooms, data = apartments)
#'
#' explainer_rf <- explain(apartments_rf_model, data = apartmentsTest[,2:5],
#'   y = apartmentsTest$m2.price)
#'
#' new_apartment <- data.frame(construction.year = 1998, surface = 88, floor = 2L, no.rooms = 3)
#'
#' profiles <- predict_profile(explainer_rf, new_apartment)
#'
#' library("vivo")
#' calculate_weight(profiles, data = apartments[, 2:5], variable_split = split)
#'
#'
#' @export
#'
calculate_weight <- function(profiles,
                             data,
                             variable_split) {
  if (!inherits(profiles, c("ceteris_paribus_explainer", "predict_profile"))) {
    stop("The calculate_weight() function requires an object created with predict_profile() or ceteris_paribus() function.")
  }
  if (!inherits(variable_split, "list")) {
    stop("The calculate_weight() function requires an object created with calculate_variable_split() function.")
  }
  if (!inherits(data, "data.frame")) {
    stop("The calculate_weight() function requires a data.frame.")
  }

  variables <- unique(colnames(data))
  n_obs <- nrow(data)
  profile_vnames <- profiles[["_vname_"]]

  # Break points for one variable: observed range of the raw data plus the
  # split points. cut() sorts the breaks itself, so their order is irrelevant.
  variable_breaks <- function(variable) {
    values <- data[[variable]]
    unique(c(min(values, na.rm = TRUE),
             variable_split[[variable]],
             max(values, na.rm = TRUE)))
  }

  weight <- lapply(variables, function(variable) {
    breaks <- variable_breaks(variable)

    # Empirical density: share of all rows of `data` falling in each interval
    # (rows with NA in this column count in the denominator only).
    density <- table(cut(data[[variable]], breaks, include.lowest = TRUE)) / n_obs

    # Interval of every profile value of this variable, in profile row order;
    # values outside the outermost breaks become NA.
    profile_values <- profiles[[variable]][profile_vnames == variable]
    intervals <- cut(profile_values, breaks, include.lowest = TRUE)

    # match() keeps the profile row order (merge() does not guarantee it and
    # moves unmatched rows to the end); unmatched intervals give NA weights.
    as.vector(density)[match(intervals, names(density))]
  })
  names(weight) <- variables
  weight
}