Skip to content

Commit

Permalink
Removed the need to use tidytext::reorder_within
Browse files Browse the repository at this point in the history
  • Loading branch information
moralapablo committed Jan 17, 2024
1 parent 9db180c commit 16fc312
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 23 deletions.
34 changes: 31 additions & 3 deletions R/nn2poly_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,32 @@ predict.nn2poly <- function(object, newdata, layers = NULL, ...) {
#' `x` is generated using `nn2poly()` with `keep_layers=TRUE`.
#'
#' @examples
#' # --- Single polynomial output ---
#' # Build a NN structure with random weights, with 2 (+ bias) inputs,
#' # 4 (+bias) neurons in the first hidden layer with "tanh" activation
#' # function, 4 (+bias) neurons in the second hidden layer with "softplus",
#' # and 2 "linear" output units
#'
#' weights_layer_1 <- matrix(rnorm(12), nrow = 3, ncol = 4)
#' weights_layer_2 <- matrix(rnorm(20), nrow = 5, ncol = 4)
#' weights_layer_3 <- matrix(rnorm(5), nrow = 5, ncol = 1)
#'
#' # Set it as a list with activation functions as names
#' nn_object = list("tanh" = weights_layer_1,
#' "softplus" = weights_layer_2,
#' "linear" = weights_layer_3)
#'
#' # Obtain the polynomial representation (order = 3) of that neural network
#' final_poly <- nn2poly(nn_object, max_order = 3)
#'
#' # Plot all the coefficients, one plot per output unit
#' plot(final_poly)
#'
#' # Plot only the 5 most important coeffcients (by absolute magnitude)
#' # one plot per output unit
#' plot(final_poly, n = 5)
#'
#' # --- Multiple output polynomials ---
#' # Build a NN structure with random weights, with 2 (+ bias) inputs,
#' # 4 (+bias) neurons in the first hidden layer with "tanh" activation
#' # function, 4 (+bias) neurons in the second hidden layer with "softplus",
Expand Down Expand Up @@ -277,11 +303,13 @@ plot.nn2poly <- function(x, ..., n=NULL) {
scale_labels <- c("-")
}

# inspired by tidytext::reorder_within
new_x <- do.call(paste, c(list(all_df$name, sep = "___"), list(all_df$type)))
reorder_aux <- stats::reorder(new_x, all_df$value, FUN = mean, decreasing = TRUE)


plot_all <- ggplot2::ggplot(all_df,
ggplot2::aes(x = tidytext::reorder_within(x = .data$name,
by = -.data$value,
within = .data$type),
ggplot2::aes(x = reorder_aux,
y = .data$value,
fill = .data$sign)) +
ggplot2::geom_bar(stat = "identity", colour = "black", alpha = 1) +
Expand Down
55 changes: 35 additions & 20 deletions tests/testthat/_snaps/nn2poly_methods_plot/top-null.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 16fc312

Please sign in to comment.