/
min_depth_distribution.R
132 lines (129 loc) · 7.28 KB
/
min_depth_distribution.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#' Calculate minimal depth distribution of a random forest
#'
#' Get minimal depth values for all trees in a random forest
#'
#' @param forest A randomForest or ranger object
#'
#' @return A data frame with the value of minimal depth for every variable in every tree
#'
#' @examples
#' min_depth_distribution(randomForest::randomForest(Species ~ ., data = iris, ntree = 100))
#' min_depth_distribution(ranger::ranger(Species ~ ., data = iris, num.trees = 100))
#'
#' @export
#' @import dplyr
#' @importFrom data.table rbindlist
min_depth_distribution <- function(forest){
tree <- NULL; `split var` <- NULL; depth <- NULL
forest_table <- forest2df(forest)
min_depth_frame <- dplyr::group_by(forest_table, tree, variable) %>%
dplyr::summarize(minimal_depth = min(depth), .groups = "drop")
min_depth_frame <- as.data.frame(min_depth_frame[!is.na(min_depth_frame$variable),])
return(min_depth_frame)
}
# Count the trees in which each variable had a given minimal depth
min_depth_count <- function(min_depth_frame){
tree <- NULL; minimal_depth <- NULL; variable <- NULL
mean_tree_depth <- dplyr::group_by(min_depth_frame, tree) %>%
dplyr::summarize(depth = max(minimal_depth) + 1) %>%
as.data.frame()
mean_tree_depth <- mean(mean_tree_depth$depth)
min_depth_count <- dplyr::group_by(min_depth_frame, variable, minimal_depth) %>%
dplyr::summarize(count = n(), .groups = "drop") %>%
as.data.frame()
occurrences <- stats::aggregate(count ~ variable, data = min_depth_count, sum)
colnames(occurrences)[2] <- "no_of_occurrences"
min_depth_count <-
data.frame(variable = occurrences$variable, minimal_depth = NA,
count = max(min_depth_frame$tree) - occurrences$no_of_occurrences) %>%
rbind(min_depth_count)
min_depth_count <- min_depth_count[order(min_depth_count$variable, min_depth_count$minimal_depth),]
rownames(min_depth_count) <- 1:nrow(min_depth_count)
return(list(min_depth_count, occurrences, mean_tree_depth))
}
# Get a data frame with means of minimal depth calculated using sample = c("all_trees", "top_trees", "relevant_trees")
get_min_depth_means <- function(min_depth_frame, min_depth_count_list, mean_sample){
.SD <- NULL; variable <- NULL
if(mean_sample == "all_trees"){
min_depth_count_list[[1]][is.na(min_depth_count_list[[1]]$minimal_depth), "minimal_depth"] <- min_depth_count_list[[3]]
min_depth_means <-
data.table::as.data.table(min_depth_count_list[[1]])[, stats::weighted.mean(.SD[["minimal_depth"]], .SD[["count"]]),
by = variable] %>% as.data.frame()
} else if(mean_sample == "top_trees"){
min_depth_count_list[[1]][is.na(min_depth_count_list[[1]]$minimal_depth), "count"] <-
min_depth_count_list[[1]][is.na(min_depth_count_list[[1]]$minimal_depth), "count"] -
min(min_depth_count_list[[1]][is.na(min_depth_count_list[[1]]$minimal_depth), "count"])
min_depth_count_list[[1]][is.na(min_depth_count_list[[1]]$minimal_depth), "minimal_depth"] <- min_depth_count_list[[3]]
min_depth_means <-
data.table::as.data.table(min_depth_count_list[[1]])[, stats::weighted.mean(.SD[["minimal_depth"]], .SD[["count"]]),
by = variable] %>% as.data.frame()
} else if(mean_sample == "relevant_trees"){
min_depth_means <- stats::aggregate(minimal_depth ~ variable, data = min_depth_frame, mean)
}
colnames(min_depth_means)[2] <- "mean_minimal_depth"
return(min_depth_means)
}
#' Plot the distribution of minimal depth in a random forest
#'
#' @param min_depth_frame A data frame output of min_depth_distribution function or a randomForest object
#' @param k The maximal number of variables with lowest mean minimal depth to be used for plotting
#' @param min_no_of_trees The minimal number of trees in which a variable has to be used for splitting to be used for plotting
#' @param mean_sample The sample of trees on which mean minimal depth is calculated, possible values are "all_trees", "top_trees", "relevant_trees"
#' @param mean_scale Logical: should the values of mean minimal depth be rescaled to the interval [0,1]?
#' @param mean_round The number of digits used for displaying mean minimal depth
#' @param main A string to be used as title of the plot
#'
#' @return A ggplot object
#'
#' @import ggplot2
#' @import dplyr
#'
#' @examples
#' forest <- randomForest::randomForest(Species ~ ., data = iris, ntree = 300)
#' plot_min_depth_distribution(min_depth_distribution(forest))
#'
#' @export
plot_min_depth_distribution <- function(min_depth_frame, k = 10, min_no_of_trees = 0,
mean_sample = "top_trees", mean_scale = FALSE, mean_round = 2,
main = "Distribution of minimal depth and its mean"){
minimal_depth <- NULL; mean_minimal_depth_label <- NULL; mean_minimal_depth <- NULL
if(any(c("randomForest", "ranger") %in% class(min_depth_frame))){
min_depth_frame <- min_depth_distribution(min_depth_frame)
}
min_depth_count_list <- min_depth_count(min_depth_frame)
min_depth_means <- get_min_depth_means(min_depth_frame, min_depth_count_list, mean_sample)
frame_with_means <- merge(min_depth_count_list[[1]], min_depth_means)
frame_with_means[is.na(frame_with_means$minimal_depth), "count"] <-
frame_with_means[is.na(frame_with_means$minimal_depth), "count"] -
min(frame_with_means[is.na(frame_with_means$minimal_depth), "count"])
if(mean_scale == TRUE){
frame_with_means$mean_minimal_depth <-
(frame_with_means$mean_minimal_depth - min(frame_with_means$mean_minimal_depth))/
(max(frame_with_means$mean_minimal_depth) - min(frame_with_means$mean_minimal_depth))
}
frame_with_means$mean_minimal_depth_label <-
(frame_with_means$mean_minimal_depth - min(frame_with_means$mean_minimal_depth))/
(max(frame_with_means$mean_minimal_depth) - min(frame_with_means$mean_minimal_depth)) *
max(min_depth_count_list[[2]]$no_of_occurrences)
variables <- min_depth_count_list[[2]][min_depth_count_list[[2]]$no_of_occurrences >= min_no_of_trees, "variable"]
frame_with_means <- frame_with_means[frame_with_means$variable %in% variables, ]
frame_with_means <-
within(frame_with_means, variable <-
factor(variable, levels = unique(frame_with_means[order(frame_with_means$mean_minimal_depth), "variable"])))
data <- frame_with_means[frame_with_means$variable %in% levels(frame_with_means$variable)[
1:min(k, length(unique(frame_with_means$variable)))], ]
data$variable <- droplevels(data$variable)
data_for_labels <- unique(data[, c("variable", "mean_minimal_depth", "mean_minimal_depth_label")])
data_for_labels$mean_minimal_depth <- round(data_for_labels$mean_minimal_depth, digits = mean_round)
plot <- ggplot(data, aes(x = variable, y = count)) +
geom_col(position = position_stack(reverse = TRUE), aes(fill = as.factor(minimal_depth))) + coord_flip() +
scale_x_discrete(limits = rev(levels(data$variable))) +
geom_errorbar(aes(ymin = mean_minimal_depth_label, ymax = mean_minimal_depth_label), linewidth = 1.5) +
xlab("Variable") + ylab("Number of trees") + guides(fill = guide_legend(title = "Minimal depth")) +
theme_bw() + geom_label(data = data_for_labels,
aes(y = mean_minimal_depth_label, label = mean_minimal_depth))
if(!is.null(main)){
plot <- plot + ggtitle(main)
}
return(plot)
}