/
fit_xy.R
115 lines (101 loc) · 3.17 KB
/
fit_xy.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#' Fit a nested model to a dataset using an xy interface.
#'
#' [generics::fit_xy()] method for nested models. This should not be
#' called directly and instead should be called by
#' [workflows::fit.workflow()].
#'
#' @param object An `nested_model` object (see [nested()]).
#' @param x A data frame of predictors.
#' @param y A data frame of outcome data.
#' @param case_weights An optional vector of case weights. Passed into
#' [parsnip::fit.model_spec()].
#' @param control A [parsnip::control_parsnip()] object. Passed into
#' [parsnip::fit.model_spec()].
#' @param ... Passed into [parsnip::fit.model_spec()]. Currently unused.
#'
#' @returns A `nested_model_fit` object with several elements:
#' * `spec`: The model specification object (the inner model of the
#' nested model object)
#' * `fit`: A tibble containing the model fits and the nests that they
#' correspond to.
#' * `inner_names`: A character vector of names, used to help with
#' nesting the data during predictions.
#'
#' @seealso [parsnip::fit.model_spec()] [parsnip::model_fit]
#'
#' @examplesIf rlang::is_installed("workflows")
#'
#' library(dplyr)
#' library(parsnip)
#' library(recipes)
#' library(workflows)
#'
#' data <- filter(example_nested_data, id %in% 11:20)
#'
#' model <- linear_reg() %>%
#' set_engine("lm") %>%
#' nested()
#'
#' recipe <- recipe(data, z ~ x + y + id) %>%
#' step_nest(id)
#'
#' wf <- workflow() %>%
#' add_recipe(recipe) %>%
#' add_model(model)
#'
#' fit(wf, data)
#'
#' @importFrom generics fit_xy
#'
#' @export fit_xy.nested_model
#' @export
fit_xy.nested_model <- function(object, x, y, case_weights = NULL,
control = parsnip::control_parsnip(), ...) {
x <- check_df(x, "x")
if (is.vector(y)) {
y <- data.frame(outcome = y)
} else {
y <- check_df(y, "y")
}
model <- extract_inner_model(object)
parallel <- object$eng_args$allow_par
pkgs <- object$eng_args$pkgs
if (!is.null(object$args)) {
model <- pass_down_args(model, object)
}
if (!".nest_id" %in% colnames(x)) {
cli::cli_abort(c(
"{.arg x} does not contain a \".nest_id\" column.",
"i" = "Try using {.fun step_nest}."
))
}
y$.nest_id <- x$.nest_id
nested_colname <- get_name(".data", c(colnames(x), colnames(y)))
nested_x <- x %>%
tidyr::nest(!!nested_colname := -".nest_id") %>%
dplyr::ungroup()
nested_y <- y %>%
tidyr::nest(!!nested_colname := -".nest_id") %>%
dplyr::ungroup()
`%op%` <- get_operator(parallel, model)
rlang::local_options(doFuture.rng.onMisuse = "ignore")
fits <- foreach::foreach(
x = nested_x[[nested_colname]],
y = nested_y[[nested_colname]],
.export = "safe_fit_xy",
.packages = unique(c(pkgs, generics::required_pkgs(model)))
) %op% {
safe_fit_xy(model, x, y, case_weights = case_weights, control = control, ...)
}
cols <- c(
colnames(purrr::compact(nested_x[[nested_colname]])[[1]]),
colnames(purrr::compact(nested_y[[nested_colname]])[[1]])
)
fit <- nested_x[, names(nested_x) != nested_colname]
fit$.model_fit <- fits
new_nested_model_fit(spec = model, fit = fit, inner_names = cols)
}
#' @noRd
safe_fit_xy <- function(...) {
try(fit_xy(...))
}