#' @title Enhanced SHAP Analysis for Binary Classification Models
#'
#' @description
#' The SHAP plot for classification models is a visualization tool that uses the Shapley value, an approach from cooperative game theory, to compute feature contributions for single predictions. The Shapley value fairly distributes the difference between an instance's prediction and the dataset's average prediction among the features. The Shapley value computation is provided by the iml package.
#'
#' @param sample.size numeric, defaults to 30. Larger values give slower but more accurate estimates of the SHAP values.
#' @param seed numeric, an integer used for reproducibility. Defaults to 246.
#' @param task mlr3 task object for binary classification
#' @param trained_model mlr3 trained learner object
#' @param splits mlr3 object defining data splits for train and test sets
#' @param subset numeric, proportion of the test-set instances to use, from 0 to 1 where 1 means all
#'
#' @importFrom magrittr %>%
#' @importFrom ggplot2 ggplot aes geom_violin geom_line coord_flip geom_jitter position_jitter scale_shape_manual labs scale_colour_gradient2 geom_text theme element_blank geom_hline element_text element_line ylim geom_point geom_smooth guides
#' @export
#'
#' @return
#' A list containing:
#' \item{shap_plot}{An enhanced, interactive SHAP plot (plotly object).}
#' \item{shap_Mean_wide}{A wide-format data.table of SHAP values: one row per test instance, one column per feature.}
#' \item{shap_Mean}{A long-format data.table of per-instance SHAP values together with mean absolute SHAP values per feature.}
#' \item{shap}{Raw SHAP values.}
#' \item{shap_pred_plot}{A plot depicting SHAP values versus predicted probabilities.}
#'
#' @examples
#' \donttest{
#' library("explainer")
#' seed <- 246
#' set.seed(seed)
#' # Load necessary packages
#' if (!requireNamespace("mlbench", quietly = TRUE)) stop("mlbench not installed.")
#' if (!requireNamespace("mlr3learners", quietly = TRUE)) stop("mlr3learners not installed.")
#' if (!requireNamespace("ranger", quietly = TRUE)) stop("ranger not installed.")
#' # Load BreastCancer dataset
#' utils::data("BreastCancer", package = "mlbench")
#' target_col <- "Class"
#' positive_class <- "malignant"
#' mydata <- BreastCancer[, -1]
#' mydata <- na.omit(mydata)
#' sex <- sample(c("Male", "Female"), size = nrow(mydata), replace = TRUE)
#' mydata$age <- as.numeric(sample(seq(18,60), size = nrow(mydata), replace = TRUE))
#' mydata$sex <- factor(sex, levels = c("Male", "Female"), labels = c(1, 0))
#' maintask <- mlr3::TaskClassif$new(
#' id = "my_classification_task",
#' backend = mydata,
#' target = target_col,
#' positive = positive_class
#' )
#' splits <- mlr3::partition(maintask)
#' mylrn <- mlr3::lrn("classif.ranger", predict_type = "prob")
#' mylrn$train(maintask, splits$train)
#' SHAP_output <- eSHAP_plot(
#' task = maintask,
#' trained_model = mylrn,
#' splits = splits,
#' sample.size = 2, # increase to 30 or more for more accurate SHAP estimates
#' seed = seed,
#' subset = 0.02 # up to 1
#' )
#' }
#'
#' @references
#' Zargari Marandi, R. (2024). ExplaineR: an R package to explain machine learning models. Bioinformatics Advances, 4(1), vbae049.
#' Molnar, C., Casalicchio, G., & Bischl, B. (2018). iml: An R package for interpretable machine learning. Journal of Open Source Software, 3(26), 786.
#'
#' @seealso
#' [eSHAP_plot_reg()]
#'
#' @keywords internal
#' @family classification
#' @family SHAP
eSHAP_plot <- function(task,
trained_model,
splits,
sample.size = 30,
seed = 246,
subset = 1) {
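# bind variable names used in non-standard evaluation (ggplot2/data.table) to avoid R CMD check notes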
cluster <- NULL
correct_prediction <- NULL
feature <- NULL
f_val <- NULL
fval <- NULL
mean_absolute_shap <- NULL
mean_phi <- NULL
Phi <- NULL
pred_class <- NULL
pred_prob <- NULL
prediction_correctness <- NULL
response <- NULL
sample_num <- NULL
truth <- NULL
unscaled_f_val <- NULL
set.seed(seed) # set seed for reproducibility
mydata <- task$data()
mydata <- as.data.frame(mydata)
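# keep only the feature columns; the target column is passed separately to the iml Predictor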
X <- mydata[, setdiff(names(mydata), task$target_names), drop = FALSE]
model <- iml::Predictor$new(trained_model, data = X, y = mydata[, task$target_names])
# randomly subset the target variable and the corresponding rows
n <- round(subset * length(splits$test))
target_index <- sample(splits$test, size = n, replace = FALSE)
mydata <- mydata[target_index, ]
# do the prediction for the test set
pred_results <- trained_model$predict(task, row_ids = target_index)
# the test set based on the data split is used to calculate SHAP values
test_set <- as.data.frame(mydata)
feature_names <- colnames(X)
nfeats <- length(feature_names)
# save the predicted probability for the positive class
pred_prob <- pred_results$prob[, task$positive]
# mark which samples were correctly predicted and which samples were not
predicted_correct <- mydata[[task$target_names]] == pred_results$response
# initialize the results list.
shap_values <- vector("list", nrow(test_set))
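# estimate SHAP values for each test instance via Monte Carlo sampling of feature coalitions (iml::Shapley)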
for (i in seq_along(shap_values)) {
# set.seed(seed)
shap_values[[i]] <- iml::Shapley$new(model,
x.interest = test_set[i,feature_names],
sample.size = sample.size)$results
shap_values[[i]]$sample_num <- i # identifier to track our instances.
shap_values[[i]]$predcorrectness <- predicted_correct[i]
shap_values[[i]]$pred_prob <- pred_prob[i]
shap_values[[i]]$pred_class <- pred_results$response[i]
}
data_shap_values <- dplyr::bind_rows(shap_values) # collapse the list.
shap <- data_shap_values[which(data_shap_values$class==task$positive),]
total_reps <- nrow(shap) / nfeats
mean_phi <- rep(0, nfeats)
indiv_phi <- vector("list", nfeats)
f_val_lst <- vector("list", nfeats)
indiv_correctness <- vector("list", nfeats)
pred_prob_rep <- vector("list", nfeats)
pred_class_rep <- vector("list", nfeats)
feature_values <- gsub(".*=",'',shap$feature.value)
shap$feature.value <- as.numeric(feature_values)
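# the SHAP results are stacked instance by instance, so every nfeats-th row belongs to the same feature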
for (i in 1:nfeats) {
mean_phi[i] <- mean(abs(shap$phi[seq(i, nrow(shap), nfeats)[predicted_correct]])) # only correct predictions are used to calculate the means
indiv_phi[i] <- list(shap$phi[seq(i, nrow(shap), nfeats)])
f_val_lst[i] <- list(feature_values[seq(i, nrow(shap), nfeats)])
indiv_correctness[i] <- list(shap$predcorrectness[seq(i, nrow(shap), nfeats)])
pred_prob_rep[i] <- list(shap$pred_prob[seq(i, nrow(shap), nfeats)])
pred_class_rep[i] <- list(shap$pred_class[seq(i, nrow(shap), nfeats)])
}
# test_set.nolab[,task$target_names:=NULL]
mydata[,task$target_names] <- NULL
# get the column names of the data frame
cols <- colnames(mydata)
# loop through each column
for (col in cols) {
# check if the column is numeric
if (!is.numeric(mydata[[col]])) {
# convert non-numeric columns to numeric
mydata[[col]] <- as.numeric(mydata[[col]])
}
}
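# (factor columns are mapped to their integer level codes, which is what the continuous colour scale expects)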
# store feature values
unscaled_f_val_lst <- f_val_lst
# apply transformation for visualization
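# range01() is a package-internal helper, assumed here to rescale a numeric vector to [0, 1] for the colour gradient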
for (i in seq_along(f_val_lst)) {
unscaled_f_val_lst[[i]] <- mydata[, i] # not scaled
f_val_lst[[i]] <- range01(mydata[, i]) # normalization
}
unscaled_f_val <- as.numeric(unlist(unscaled_f_val_lst))
f_val <- as.numeric(unlist(f_val_lst))
Phi <- unlist(indiv_phi)
shap_Mean <- data.table::data.table(feature=rep(feature_names,each=total_reps),
mean_phi = rep(mean_phi,each=total_reps),
Phi = Phi,
f_val = f_val,
unscaled_f_val = unscaled_f_val,
sample_num = rep(1:nrow(test_set),length(feature_names)),
correct_prediction = unlist(indiv_correctness),
pred_prob = unlist(pred_prob_rep),
pred_class = unlist(pred_class_rep))
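# reshape to wide format: one row per test instance, one column of SHAP values per feature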
shap_Mean_wide <- data.table::dcast(shap_Mean, sample_num ~ feature, value.var="Phi")
shap_Mean$correct_prediction <- factor(shap_Mean$correct_prediction, levels = c(FALSE, TRUE), labels = c("Incorrect","Correct"))
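# violin + jitter summary: each point is one instance-feature SHAP value; shape marks prediction correctness and colour encodes the scaled feature value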
shap_plot <- shap_Mean %>%
dplyr::mutate(feature = forcats::fct_reorder(feature, mean_phi)) %>%
ggplot(aes(x = feature, y = Phi, color = f_val))+
geom_violin(colour = "grey") +
geom_line(aes(group = sample_num), alpha = 0.1, linewidth = 0.2) +
coord_flip() +
geom_jitter(aes(shape=correct_prediction, text = paste("Feature: ", feature,
"<br>Unscaled feature value: ", unscaled_f_val,
"<br>SHAP value: ", Phi,
"<br>Prediction correctness: ", correct_prediction,
"<br>Predicted probability: ", pred_prob,
"<br>Predicted class: ", pred_class)),
alpha = 0.6, size=1.5, position=position_jitter(width=0.2, height=0)) +
scale_shape_manual(values = c(4, 19), guide = "none") +
# scale_color_manual(values=c("black","grey")) +
labs(shape = "model prediction") +
scale_colour_gradient2(low = "blue", mid = "green", high = "red", midpoint = 0.5, breaks = c(0, 1), labels = c("Low", "High")) +
guides(shape = ggplot2::guide_legend(override.aes = list(fill = "black", color = "black"))) +
geom_text(aes(x = feature, y=-Inf, label = sprintf("%.3f", mean_phi)), hjust = -0.2, alpha = 0.7, color = "black") +
theme(axis.line.y = element_blank(), axis.ticks.y = element_blank(),
legend.position="right") +
geom_hline(yintercept = 0, color = "grey") + # the vertical line
labs(y = "SHAP decision plot - test set", x = "features", color = "feature values scaled\n to [low=0 high=1]") +
theme(text = element_text(size = 10, family="Helvetica"),
# Remove panel border
panel.border = element_blank(),
# Remove panel grid lines
panel.grid.major = element_blank(),
panel.grid.minor = element_blank(),
# Remove panel background
panel.background = element_blank(),
# Add axis line
axis.line = element_line(colour = "grey"),
legend.key.width = grid::unit(2,"mm")) +
ylim(min(shap_Mean$Phi)-0.05, max(shap_Mean$Phi)+0.05)
# Convert ggplot to Plotly
shap_plot <- plotly::ggplotly(shap_plot, tooltip = "text")
# Additional plot to show SHAP values vs. predicted probabilities
shap_pred_plot <- shap_Mean %>%
ggplot(aes(x = Phi, y = pred_prob, shape=pred_class)) +
geom_point() +
geom_smooth(method = "loess", se = FALSE) +
labs(x = "SHAP value", y = "Predicted probability")
return(list(shap_plot = shap_plot, shap_Mean_wide = shap_Mean_wide, shap_Mean = shap_Mean, shap = shap, shap_pred_plot = shap_pred_plot))
}