/
caret-tidiers.R
132 lines (126 loc) · 3.83 KB
/
caret-tidiers.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#' @templateVar class confusionMatrix
#' @template title_desc_tidy
#'
#' @param x An object of class `confusionMatrix` created by a call to
#' [caret::confusionMatrix()].
#' @param by_class Logical indicating whether or not to show performance
#' measures broken down by class. Defaults to `TRUE`. When `by_class = FALSE`
#' only returns a tibble with accuracy, kappa, and McNemar statistics.
#' @template param_unused_dots
#'
#' @evalRd return_tidy(
#' "term",
#' "estimate",
#' "conf.low",
#' "conf.high",
#' "class",
#' p.value = "P-value for the accuracy and McNemar test statistics."
#' )
#'
#' @examplesIf rlang::is_installed("caret")
#'
#' # load libraries for models and data
#' library(caret)
#'
#' set.seed(27)
#'
#' # generate data
#' two_class_sample1 <- as.factor(sample(letters[1:2], 100, TRUE))
#' two_class_sample2 <- as.factor(sample(letters[1:2], 100, TRUE))
#'
#' two_class_cm <- confusionMatrix(
#' two_class_sample1,
#' two_class_sample2
#' )
#'
#' # summarize model fit with tidiers
#' tidy(two_class_cm)
#' tidy(two_class_cm, by_class = FALSE)
#'
#' # multiclass example
#' six_class_sample1 <- as.factor(sample(letters[1:6], 100, TRUE))
#' six_class_sample2 <- as.factor(sample(letters[1:6], 100, TRUE))
#'
#' six_class_cm <- confusionMatrix(
#' six_class_sample1,
#' six_class_sample2
#' )
#'
#' # summarize model fit with tidiers
#' tidy(six_class_cm)
#' tidy(six_class_cm, by_class = FALSE)
#'
#' @aliases caret_tidiers confusionMatrix_tidiers
#' @export
#' @seealso [tidy()], [caret::confusionMatrix()]
tidy.confusionMatrix <- function(x, by_class = TRUE, ...) {
  # Overall statistics as a named list so single entries can be pulled out
  # with `$`.
  cm <- as.list(x$overall)
  # The first three rows are always accuracy, kappa and McNemar's test.
  nms_cm <- stringr::str_to_lower(c(names(cm)[1:2], "McNemar"))

  # Per-class measures are appended after the three overall rows. They stay
  # empty when `by_class = FALSE`, so the output is built the same way in
  # every case.
  class_terms <- character()
  class_labels <- character()
  class_values <- numeric()

  if (by_class) {
    if (!inherits(x$byClass, "matrix")) {
      # `byClass` is not a matrix (the two-class case): one value per
      # measure, all of them labelled with the positive class.
      classes <-
        x$byClass %>%
        as.data.frame() %>%
        dplyr::rename(value = 1) %>%
        tibble::rownames_to_column("var") %>%
        mutate(var = stringr::str_to_lower(gsub(" ", "_", var)))
      class_labels <- rep(x$positive, nrow(classes))
    } else {
      # `byClass` is a matrix (more than two classes): one row per class and
      # one column per measure, reshaped to one row per class/measure pair.
      classes <-
        x$byClass %>%
        as.data.frame() %>%
        tibble::rownames_to_column("class") %>%
        pivot_longer(
          cols = -class,
          names_to = "var",
          values_to = "value"
        ) %>%
        mutate(
          var = stringr::str_to_lower(gsub(" ", "_", var)),
          class = gsub("Class: ", "", class)
        )
      class_labels <- classes$class
    }
    class_terms <- classes$var
    class_values <- classes$value
  }

  n_class_rows <- length(class_terms)

  # McNemar's test contributes a p-value only, so its estimate is NA. A
  # confidence interval is filled in for accuracy only, and p-values for
  # accuracy and McNemar's test only; every other cell is NA.
  df <- tibble(
    term = c(nms_cm, class_terms),
    class = c(rep(NA_character_, 3), class_labels),
    estimate = c(cm$Accuracy, cm$Kappa, NA, class_values),
    conf.low = c(cm$AccuracyLower, rep(NA, n_class_rows + 2)),
    conf.high = c(cm$AccuracyUpper, rep(NA, n_class_rows + 2)),
    p.value = c(
      cm$AccuracyPValue, NA, cm$McnemarPValue,
      rep(NA, n_class_rows)
    )
  )

  if (!by_class) {
    # Only accuracy, kappa and McNemar are shown when `by_class = FALSE`, and
    # the `class` column is omitted.
    df$class <- NULL
  }

  as_tidy_tibble(df)
}