/
glmnet-cv-glmnet-tidiers.R
109 lines (108 loc) · 2.66 KB
/
glmnet-cv-glmnet-tidiers.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#' @templateVar class cv.glmnet
#' @template title_desc_tidy
#'
#' @param x A `cv.glmnet` object returned from [glmnet::cv.glmnet()].
#' @template param_unused_dots
#'
#' @evalRd return_tidy(
#'   "lambda",
#'   "std.error",
#'   "nzero",
#'   conf.low = "lower bound on confidence interval for cross-validation
#'     estimated loss.",
#'   conf.high = "upper bound on confidence interval for cross-validation
#'     estimated loss.",
#'   estimate = "Mean loss across all cross-validation folds for a given
#'     lambda."
#' )
#'
#' @examplesIf rlang::is_installed(c("glmnet", "ggplot2"))
#'
#' # load libraries for models and data
#' library(glmnet)
#'
#' set.seed(27)
#'
#' nobs <- 100
#' nvar <- 50
#' real <- 5
#'
#' x <- matrix(rnorm(nobs * nvar), nobs, nvar)
#' beta <- c(rnorm(real, 0, 1), rep(0, nvar - real))
#' # one noise draw per observation (not per variable, which would recycle)
#' y <- c(t(beta) %*% t(x)) + rnorm(nobs, sd = 3)
#'
#' cvfit1 <- cv.glmnet(x, y)
#'
#' tidy(cvfit1)
#' glance(cvfit1)
#'
#' library(ggplot2)
#'
#' tidied_cv <- tidy(cvfit1)
#' glance_cv <- glance(cvfit1)
#'
#' # plot of MSE as a function of lambda
#' g <- ggplot(tidied_cv, aes(lambda, estimate)) +
#'   geom_line() +
#'   scale_x_log10()
#' g
#'
#' # plot of MSE as a function of lambda with confidence ribbon
#' g <- g + geom_ribbon(aes(ymin = conf.low, ymax = conf.high), alpha = .25)
#' g
#'
#' # plot of MSE as a function of lambda with confidence ribbon and choices
#' # of minimum lambda marked
#' g <- g +
#'   geom_vline(xintercept = glance_cv$lambda.min) +
#'   geom_vline(xintercept = glance_cv$lambda.1se, lty = 2)
#' g
#'
#' # plot of number of nonzero coefficients for each choice of lambda
#' ggplot(tidied_cv, aes(lambda, nzero)) +
#'   geom_line() +
#'   scale_x_log10()
#'
#' # coefficient plot with min lambda shown
#' tidied <- tidy(cvfit1$glmnet.fit)
#'
#' ggplot(tidied, aes(lambda, estimate, group = term)) +
#'   scale_x_log10() +
#'   geom_line() +
#'   geom_vline(xintercept = glance_cv$lambda.min) +
#'   geom_vline(xintercept = glance_cv$lambda.1se, lty = 2)
#'
#' @export
#' @family glmnet tidiers
#' @seealso [tidy()], [glmnet::cv.glmnet()]
tidy.cv.glmnet <- function(x, ...) {
  # Read each component straight off the fitted object (same access style as
  # glance.cv.glmnet) instead of evaluating bare names through `with()`.
  # One row per value in `x$lambda`; columns are renamed to the tidy
  # vocabulary. Per the glmnet documentation, `cvm` is the mean
  # cross-validated error and `cvlo` / `cvup` are `cvm` -/+ `cvsd`.
  tibble(
    lambda = x$lambda,
    estimate = x$cvm,
    std.error = x$cvsd,
    conf.low = x$cvlo,
    conf.high = x$cvup,
    nzero = x$nzero
  )
}
#' @templateVar class cv.glmnet
#' @template title_desc_glance
#'
#' @inherit tidy.cv.glmnet params examples
#'
#' @evalRd return_glance("lambda.min", "lambda.1se", "nobs")
#'
#' @export
#' @seealso [glance()], [glmnet::cv.glmnet()]
#' @family glmnet tidiers
glance.cv.glmnet <- function(x, ...) {
  # Gather the model-level summaries first: the two lambda values stored on
  # the fit, and the observation count reported by the `nobs()` generic.
  lambda_min <- x$lambda.min
  lambda_1se <- x$lambda.1se
  n_obs <- stats::nobs(x)

  # NOTE(review): `na_types` presumably carries one type code per column, in
  # order (r = real, r = real, i = integer) -- confirm against
  # as_glance_tibble().
  as_glance_tibble(
    lambda.min = lambda_min,
    lambda.1se = lambda_1se,
    nobs = n_obs,
    na_types = "rri"
  )
}