Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 10 additions & 19 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Workflow derived from https://github.com/r-lib/actions/tree/master/examples
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
on:
push:
Expand All @@ -18,7 +18,7 @@ jobs:
fail-fast: false
matrix:
config:
- {os: macOS-latest, r: 'release'}
- {os: macos-latest, r: 'release'}
- {os: windows-latest, r: 'release'}
- {os: ubuntu-latest, r: 'devel', http-user-agent: 'release'}
- {os: ubuntu-latest, r: 'release'}
Expand All @@ -29,30 +29,21 @@ jobs:
R_KEEP_PKG_SOURCE: yes

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3

- uses: r-lib/actions/setup-pandoc@v1
- uses: r-lib/actions/setup-pandoc@v2

- uses: r-lib/actions/setup-r@v1
- uses: r-lib/actions/setup-r@v2
with:
r-version: ${{ matrix.config.r }}
http-user-agent: ${{ matrix.config.http-user-agent }}
use-public-rspm: true

- uses: r-lib/actions/setup-r-dependencies@v1
- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: rcmdcheck
extra-packages: any::rcmdcheck
needs: check

- uses: r-lib/actions/check-r-package@v1

- name: Show testthat output
if: always()
run: find check -name 'testthat.Rout*' -exec cat '{}' \; || true
shell: bash

- name: Upload check results
if: failure()
uses: actions/upload-artifact@main
- uses: r-lib/actions/check-r-package@v2
with:
name: ${{ runner.os }}-r${{ matrix.config.r }}-results
path: check
upload-snapshots: true
14 changes: 8 additions & 6 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: randomForestExplainer
Title: Explaining and Visualizing Random Forests in Terms of Variable Importance
Version: 0.10.1
Version: 0.10.2
Authors@R: c(
person("Aleksandra", "Paluszynska", email = "ola.paluszynska@gmail.com", role = c("aut")),
person("Przemyslaw", "Biecek", email = "przemyslaw.biecek@gmail.com", role = c("aut","ths")),
Expand All @@ -10,22 +10,24 @@ Description: A set of tools to help explain which variables are most important i
Depends: R (>= 3.0)
License: GPL
Encoding: UTF-8
LazyData: true
Imports:
data.table (>= 1.10.4),
dplyr (>= 0.7.1),
DT (>= 0.2),
GGally (>= 1.3.0),
ggplot2 (>= 2.2.1),
ggplot2 (>= 3.4.0),
ggrepel (>= 0.6.5),
randomForest (>= 4.6.12),
ranger(>= 0.9.0),
reshape2 (>= 1.4.2),
rmarkdown (>= 1.5)
rlang,
rmarkdown (>= 1.5),
tidyr
Suggests:
knitr,
MASS (>= 7.3.47),
testthat
VignetteBuilder: knitr
RoxygenNote: 7.1.0
RoxygenNote: 7.3.0
URL: https://github.com/ModelOriented/randomForestExplainer, https://modeloriented.github.io/randomForestExplainer/
Config/testthat/edition: 3
Config/Needs/website: ModelOriented/DrWhyTemplate
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import(ggplot2)
import(ggrepel)
importFrom(data.table,frankv)
importFrom(data.table,rbindlist)
importFrom(rlang,.data)
importFrom(stats,as.formula)
importFrom(stats,predict)
importFrom(stats,terms)
17 changes: 15 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
# randomForestExplainer 0.10.2

* Remove dependency on reshape2 in favour of tidyr (@olivroy, #33)

* Silence deprecation warnings from ggplot2 and dplyr (@olivroy, #29)

* Use testthat 3rd edition. (@olivroy, #33)

# randomForestExplainer 0.10.1

* Small tweaks to `explain_forest()`.

# randomForestExplainer 0.10.0

## New features
* Added support for ranger forests.
* Added support for unsupervised randomForest.
* Added tests for most functions.

## Bug fixes
* Fixed bug for explain_forest not finding templates.
* Added more intuitive error message for explain_forest when local importance is absent.
* Fixed bug for `explain_forest()` not finding templates.
* Added more intuitive error message for `explain_forest()` when local `importance` is absent.
20 changes: 8 additions & 12 deletions R/measure_importance.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ measure_min_depth <- function(min_depth_frame, mean_sample){
# randomForest
measure_no_of_nodes <- function(forest_table){
`split var` <- NULL
frame <- dplyr::group_by(forest_table, `split var`) %>% dplyr::summarize(n())
colnames(frame) <- c("variable", "no_of_nodes")
frame <- dplyr::group_by(forest_table, variable = `split var`) %>% dplyr::summarize(no_of_nodes = dplyr::n())
frame <- as.data.frame(frame[!is.na(frame$variable),])
frame$variable <- as.character(frame$variable)
return(frame)
Expand All @@ -21,8 +20,7 @@ measure_no_of_nodes <- function(forest_table){
# randomForest
measure_no_of_nodes_ranger <- function(forest_table){
splitvarName <- NULL
frame <- dplyr::group_by(forest_table, splitvarName) %>% dplyr::summarize(n())
colnames(frame) <- c("variable", "no_of_nodes")
frame <- dplyr::group_by(forest_table, variable = splitvarName) %>% dplyr::summarize(no_of_nodes = n())
frame <- as.data.frame(frame[!is.na(frame$variable),])
frame$variable <- as.character(frame$variable)
return(frame)
Expand Down Expand Up @@ -75,8 +73,7 @@ measure_vimp_ranger <- function(forest){
measure_no_of_trees <- function(min_depth_frame){
variable <- NULL
frame <- dplyr::group_by(min_depth_frame, variable) %>%
dplyr::summarize(count = n()) %>% as.data.frame()
colnames(frame)[2] <- "no_of_trees"
dplyr::summarize(no_of_trees = n()) %>% as.data.frame()
frame$variable <- as.character(frame$variable)
return(frame)
}
Expand All @@ -85,8 +82,7 @@ measure_no_of_trees <- function(min_depth_frame){
measure_times_a_root <- function(min_depth_frame){
variable <- NULL
frame <- min_depth_frame[min_depth_frame$minimal_depth == 0, ] %>%
dplyr::group_by(variable) %>% dplyr::summarize(count = n()) %>% as.data.frame()
colnames(frame)[2] <- "times_a_root"
dplyr::group_by(variable) %>% dplyr::summarize(times_a_root = n()) %>% as.data.frame()
frame$variable <- as.character(frame$variable)
return(frame)
}
Expand Down Expand Up @@ -329,13 +325,13 @@ plot_multi_way_importance <- function(importance_frame, x_measure = "mean_min_de
if(size_measure == "p_value"){
data$p_value <- cut(data$p_value, breaks = c(-Inf, 0.01, 0.05, 0.1, Inf),
labels = c("<0.01", "[0.01, 0.05)", "[0.05, 0.1)", ">=0.1"), right = FALSE)
plot <- ggplot(data, aes_string(x = x_measure, y = y_measure)) +
geom_point(aes_string(color = size_measure), size = 3) +
plot <- ggplot(data, aes(x = .data[[x_measure]], y = .data[[y_measure]])) +
geom_point(aes(color = .data[[size_measure]]), size = 3) +
geom_point(data = data_for_labels, color = "black", stroke = 2, aes(alpha = "top"), size = 3, shape = 21) +
geom_label_repel(data = data_for_labels, aes(label = variable), show.legend = FALSE) +
theme_bw() + scale_alpha_discrete(name = "variable", range = c(1, 1))
} else {
plot <- ggplot(data, aes_string(x = x_measure, y = y_measure, size = size_measure)) +
plot <- ggplot(data, aes(x = .data[[x_measure]], y = .data[[y_measure]], size = .data[[size_measure]])) +
geom_point(aes(colour = "black")) + geom_point(data = data_for_labels, aes(colour = "blue")) +
geom_label_repel(data = data_for_labels, aes(label = variable, size = NULL), show.legend = FALSE) +
scale_colour_manual(name = "variable", values = c("black", "blue"), labels = c("non-top", "top")) +
Expand All @@ -345,7 +341,7 @@ plot_multi_way_importance <- function(importance_frame, x_measure = "mean_min_de
}
}
} else {
plot <- ggplot(data, aes_string(x = x_measure, y = y_measure)) +
plot <- ggplot(data, aes(x = .data[[x_measure]], y = .data[[y_measure]])) +
geom_point(aes(colour = "black")) + geom_point(data = data_for_labels, aes(colour = "blue")) +
geom_label_repel(data = data_for_labels, aes(label = variable, size = NULL), show.legend = FALSE) +
scale_colour_manual(name = "variable", values = c("black", "blue"), labels = c("non-top", "top")) +
Expand Down
2 changes: 1 addition & 1 deletion R/min_depth_distribution.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ plot_min_depth_distribution <- function(min_depth_frame, k = 10, min_no_of_trees
plot <- ggplot(data, aes(x = variable, y = count)) +
geom_col(position = position_stack(reverse = TRUE), aes(fill = as.factor(minimal_depth))) + coord_flip() +
scale_x_discrete(limits = rev(levels(data$variable))) +
geom_errorbar(aes(ymin = mean_minimal_depth_label, ymax = mean_minimal_depth_label), size = 1.5) +
geom_errorbar(aes(ymin = mean_minimal_depth_label, ymax = mean_minimal_depth_label), linewidth = 1.5) +
xlab("Variable") + ylab("Number of trees") + guides(fill = guide_legend(title = "Minimal depth")) +
theme_bw() + geom_label(data = data_for_labels,
aes(y = mean_minimal_depth_label, label = mean_minimal_depth))
Expand Down
Loading