/
ceteris_paribus.R
198 lines (182 loc) · 8.86 KB
/
ceteris_paribus.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#' Ceteris Paribus Profiles aka Individual Variable Profiles
#'
#' This explainer works for individual observations.
#' For each observation it calculates Ceteris Paribus Profiles for selected variables.
#' Such profiles can be used to hypothesize about model results if selected variable is changed.
#' For this reason it is also called 'What-If Profiles'.
#'
#' Find more details in \href{https://ema.drwhy.ai/ceterisParibus.html}{Ceteris Paribus Chapter}.
#'
#' @param x an explainer created with the \code{DALEX::explain()} function, or a model to be explained.
#' @param data validation dataset. It will be extracted from \code{x} if it's an explainer
#' NOTE: It is best when target variable is not present in the \code{data}
#' @param predict_function predict function. It will be extracted from \code{x} if it's an explainer
#' @param new_observation a new observation with columns that correspond to variables used in the model
#' @param y true labels for \code{new_observation}. If specified then will be added to ceteris paribus plots.
#' NOTE: It is best when target variable is not present in the \code{new_observation}
#' @param variables names of variables for which profiles shall be calculated.
#' Will be passed to \code{\link{calculate_variable_split}}.
#' If NULL then all variables from the validation data will be used.
#' @param ... other parameters
#' @param variable_splits named list of splits for variables, in most cases created with \code{\link{calculate_variable_split}}.
#' If NULL then it will be calculated based on validation data available in the \code{explainer}.
#' @param grid_points maximum number of points for profile calculations. Note that the final number of points may be lower than \code{grid_points}, e.g. if there are not enough unique values for a given variable. Will be passed to \code{\link{calculate_variable_split}}.
#' @param label name of the model. By default it's extracted from the \code{class} attribute of the model
#' @param variable_splits_type how variable grids shall be calculated? Use "quantiles" (default) for percentiles or "uniform" to get uniform grid of points
#' @param variable_splits_with_obs if \code{TRUE} then all values in \code{new_observation} will be included in \code{variable_splits}
#'
#' @references Explanatory Model Analysis. Explore, Explain, and Examine Predictive Models. \url{https://ema.drwhy.ai/}
#'
#' @return an object of the class \code{ceteris_paribus_explainer}.
#'
#' @examples
#' library("DALEX")
#' library("ingredients")
#' titanic_small <- select_sample(titanic_imputed, n = 500, seed = 1313)
#'
#' # build a model
#' model_titanic_glm <- glm(survived ~ gender + age + fare,
#' data = titanic_small,
#' family = "binomial")
#'
#' explain_titanic_glm <- explain(model_titanic_glm,
#' data = titanic_small[,-8],
#' y = titanic_small[,8])
#'
#' cp_rf <- ceteris_paribus(explain_titanic_glm, titanic_small[1,])
#' cp_rf
#'
#' plot(cp_rf, variables = "age")
#'
#' \donttest{
#' library("ranger")
#' model_titanic_rf <- ranger(survived ~., data = titanic_imputed, probability = TRUE)
#'
#'
#' explain_titanic_rf <- explain(model_titanic_rf,
#' data = titanic_imputed[,-8],
#' y = titanic_imputed[,8],
#' label = "ranger forest",
#' verbose = FALSE)
#'
#' # select a few passengers
#' selected_passangers <- select_sample(titanic_imputed, n = 20)
#' cp_rf <- ceteris_paribus(explain_titanic_rf, selected_passangers)
#' cp_rf
#'
#' plot(cp_rf, variables = "age") +
#' show_observations(cp_rf, variables = "age") +
#' show_rugs(cp_rf, variables = "age", color = "red")
#'
#' }
#'
#' @export
#' @rdname ceteris_paribus
ceteris_paribus <- function(x, ...) {
  # S3 generic: the method is selected by the class of `x`
  UseMethod("ceteris_paribus")
}
#' @export
#' @rdname ceteris_paribus
ceteris_paribus.explainer <- function(x,
                                      new_observation,
                                      y = NULL,
                                      variables = NULL,
                                      variable_splits = NULL,
                                      grid_points = 101,
                                      variable_splits_type = "quantiles",
                                      ...) {
  # Method for explainer objects: the model, validation data, predict
  # function and label are read from the explainer's fields and handed,
  # together with the caller's arguments, to the default method.
  ceteris_paribus.default(
    x = x$model,
    data = x$data,
    predict_function = x$predict_function,
    new_observation = new_observation,
    y = y,
    variables = variables,
    variable_splits = variable_splits,
    grid_points = grid_points,
    label = x$label,
    variable_splits_type = variable_splits_type,
    ...
  )
}
#' @export
#' @rdname ceteris_paribus
ceteris_paribus.default <- function(x,
                                    data,
                                    predict_function = predict,
                                    new_observation,
                                    y = NULL,
                                    variables = NULL,
                                    variable_splits = NULL,
                                    grid_points = 101,
                                    variable_splits_type = "quantiles",
                                    variable_splits_with_obs = FALSE,
                                    label = class(x)[1],
                                    ...) {
  # Builds ceteris paribus profiles for `new_observation` and returns them as
  # a `ceteris_paribus_explainer` data frame. The observations themselves,
  # extended with `_yhat_`, `_label_`, `_ids_` (and `_y_` when `y` is given),
  # are stored in the "observations" attribute.

  # When `data` is a data.frame, restrict both `data` and `new_observation`
  # to the columns they share (e.g. when only some variables are supplied).
  if (is.data.frame(data)) {
    common_variables <- intersect(colnames(new_observation), colnames(data))
    new_observation <- new_observation[, common_variables, drop = FALSE]
    data <- data[, common_variables, drop = FALSE]
  }

  # Splits are computed from `data` only when the caller did not supply them.
  if (is.null(variable_splits)) {
    # validation data is required to derive the splits
    if (is.null(data)) {
      stop("The ceteris_paribus() function requires explainers created with specified 'data'.")
    }
    # default to every column of the validation data
    if (is.null(variables)) {
      variables <- colnames(data)
    }
    variable_splits <- calculate_variable_split(
      data,
      variables = variables,
      grid_points = grid_points,
      variable_splits_type = variable_splits_type,
      new_observation = if (variable_splits_with_obs) new_observation else NA
    )
  }

  # calculate profiles
  profiles <- calculate_variable_profile(new_observation,
                                         variable_splits, x, predict_function)

  # If there is more than one column starting with `_yhat_`,
  # they are stacked into a single `_yhat_` column below.
  col_yhat <- grep(colnames(profiles), pattern = "^_yhat_")
  if (length(col_yhat) == 1) {
    profiles$`_label_` <- label

    # add points of interest
    predictions <- predict_function(x, new_observation)
    # if new_observation is a matrix then turn into data.frame. see #26
    if (!is.data.frame(new_observation)) {
      new_observation <- as.data.frame(new_observation)
    }
    new_observation$`_yhat_` <- predictions
    new_observation$`_label_` <- label
    new_observation$`_ids_` <- seq_len(nrow(new_observation))
    if (!is.null(y)) new_observation$`_y_` <- y
  } else {
    n_outputs <- length(col_yhat)
    n_profiles <- nrow(profiles)
    n_obs <- nrow(new_observation)

    # Stack the profile rows once per `_yhat_` column; the suffix left after
    # removing "_yhat_" from each column name is appended to the label.
    new_profiles <- profiles[rep(seq_len(n_profiles), times = n_outputs),
                             -col_yhat, drop = FALSE]
    new_profiles$`_yhat_` <- unlist(c(profiles[, col_yhat]))
    stripped_names <- gsub(colnames(profiles)[col_yhat],
                           pattern = "_yhat_", replacement = "")
    new_profiles$`_label_` <- paste0(label, rep(stripped_names, each = n_profiles))
    profiles <- new_profiles

    # add points of interest
    # Rows are laid out as obs 1..n for the first output, then obs 1..n for
    # the second, and so on. `drop = FALSE` keeps a single-column observation
    # from collapsing into a bare vector.
    new_observation_ext <- new_observation[rep(seq_len(n_obs), times = n_outputs),
                                           , drop = FALSE]
    predict_obs <- predict_function(x, new_observation)
    # if new_observation is a matrix then turn into data.frame. see #26
    if (!is.data.frame(new_observation_ext)) {
      new_observation_ext <- as.data.frame(new_observation_ext)
    }
    new_observation_ext$`_yhat_` <- unlist(c(predict_obs))
    new_observation_ext$`_label_` <- paste0(label, rep(stripped_names, each = n_obs))
    # ids must follow the same 1..n, 1..n, ... layout as the replicated rows
    # (`times`, not `each`), otherwise they point at the wrong observation
    new_observation_ext$`_ids_` <- rep(seq_len(n_obs), times = n_outputs)
    # add y
    if (!is.null(y)) new_observation_ext$`_y_` <- rep(y, times = n_outputs)
    new_observation <- new_observation_ext
  }

  # prepare final object
  attr(profiles, "observations") <- new_observation
  class(profiles) <- c("ceteris_paribus_explainer", "data.frame")
  profiles
}