-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
846dddd
commit 4621854
Showing
9 changed files
with
243 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
#' @describeIn causal_mod Construct a `causal_mod` with minimal checks | ||
#' @export | ||
new_causal_mod <- function(x = list(), fitted = double(0), idx = seq_along(fitted)) { | ||
if (is.atomic(x)) { | ||
cli_abort("{.arg x} cannot be an atomic type.") | ||
} | ||
|
||
names(fitted) <- NULL | ||
vctrs::new_vctr(fitted[idx], model=x, idx=idx, class="causal_mod") | ||
} | ||
|
||
#' Construct a fitted model column | ||
#' | ||
#' The `causal_mod` class acts like a vector of fitted values, but it also | ||
#' stores the fitted model for later predictions and summaries, and keeps track | ||
#' of observations that have been dropped (e.g. due to missingness) or subsetted. | ||
#' | ||
#' @param x | ||
#' * For `causal_mod()` and `new_causal_mod()`: A fitted model object. | ||
#' For `causal_mod()` this should support the [fitted()] generic. | ||
#' * For `is_causal_mod()`: An object to test | ||
#' @param idx A set of indices that connect observations fed into the model with | ||
#' fitted values. Should contain values from 1 to the number of fitted values, | ||
#' and may contain `NA` values for observations with no corresponding fitted | ||
#' value (such as those with missing data). Defaults to a sequence over the | ||
#' fitted values, with `NA`s determined by [na.action()]. | ||
#' @param fitted A vector of fitted values. Extracted automatically from the | ||
#' model object in `causal_mod()`. | ||
#' | ||
#' @returns A `causal_mod` object. For `is_causal_mod()`, a logical value. | ||
#' | ||
#' @examples | ||
#' m <- lm(yield ~ block + N*P*K, data=npk) | ||
#' causal_mod(m) | ||
#' | ||
#' d <- rbind(NA, npk) | ||
#' m_mis <- lm(yield ~ block + N*P*K, data=d) | ||
#' causal_mod(m_mis) # NA for missing value | ||
#' | ||
#' @order 1 | ||
#' @export | ||
causal_mod <- function(x, idx = NULL) { | ||
if (is.atomic(x)) { | ||
cli_abort("{.arg x} cannot be an atomic type.") | ||
} | ||
|
||
fitted <- stats::fitted(x) | ||
if (is.null(fitted)) { | ||
cli_abort("{.arg x} does not have a {.fn fitted} method.") | ||
} | ||
if (is.null(idx)) { | ||
nas <- na.action(x) | ||
if (is.null(nas)) { | ||
idx = seq_along(fitted) | ||
} else { | ||
idx = integer(length(fitted) + length(nas)) | ||
idx[nas] = NA_integer_ | ||
idx[-nas] = seq_along(fitted) | ||
} | ||
} | ||
|
||
new_causal_mod(x, fitted, idx) | ||
} | ||
|
||
#' @describeIn causal_mod Return `TRUE` if an object is an `causal_mod` list | ||
#' @export | ||
is_causal_mod <- function(x) { | ||
inherits(x, "causal_mod") | ||
} | ||
|
||
|
||
# printing | ||
#' @export | ||
format.causal_mod <- function(x, ...) { | ||
formatC(vctrs::vec_data(x)) | ||
} | ||
#' @export | ||
str.causal_mod <- function(object, max.level=2, ...) { | ||
NextMethod(max.level=max.level, ...) # nocov | ||
} | ||
|
||
# vctrs ------------------------------------------------------------------- | ||
|
||
#' @export | ||
`[.causal_mod` <- function(x, i) { | ||
out <- NextMethod() | ||
attr(out, "idx") <- attr(out, "idx")[i] | ||
out | ||
} | ||
|
||
#' @importFrom vctrs vec_ptype_abbr | ||
#' @method vec_ptype_abbr causal_mod | ||
#' @export | ||
vec_ptype_abbr.causal_mod <- function(x, ...) { | ||
"mod" # nocov | ||
} | ||
|
||
#' @importFrom vctrs vec_ptype2 | ||
#' @export | ||
vec_ptype2.double.causal_mod <- function(x, y, ...) double() # nocov | ||
#' @export | ||
vec_ptype2.causal_mod.double <- function(x, y, ...) double() | ||
#' @importFrom vctrs vec_cast | ||
#' @export | ||
vec_cast.double.causal_mod <- function(x, to, ...) vctrs::vec_data(x) |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# causal_mod printing | ||
|
||
Code | ||
print(x) | ||
Output | ||
<causal_mod[24]> | ||
[1] 50.89 58.33 51.82 55.06 65.1 55.7 53.33 55.67 58.99 59.02 68.42 56.66 | ||
[13] 57.78 48.38 46.01 48.34 54.82 48.32 51.56 47.39 57.38 60.65 53.22 54.15 | ||
|
||
--- | ||
|
||
Code | ||
str(x) | ||
Output | ||
mod [1:24] 50.89, 58.33, 51.82, 55.06, 65.1, 55.7, 53.33, 55.67, 58.99, 59... | ||
@ model:List of 13 | ||
..- attr(*, "class")= chr "lm" | ||
@ idx : int [1:24] 1 2 3 4 5 6 7 8 9 10 ... | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
|
||
test_that("causal_mod constructor", { | ||
m <- lm(yield ~ block + N*P*K, data=npk) | ||
x <- causal_mod(m) | ||
expect_s3_class(x, "causal_mod") | ||
expect_type(x, "double") | ||
expect_true(is_causal_mod(x)) | ||
|
||
expect_error(causal_mod(5), "atomic") | ||
expect_error(new_causal_mod(5), "atomic") | ||
expect_error(causal_mod(list()), "fitted()") | ||
|
||
d = rbind(NA, npk, NA, npk) | ||
m <- lm(yield ~ block + N*P*K, data=d) | ||
x <- causal_mod(m) | ||
expect_equal(which(is.na(x)), unname(c(na.action(m)))) | ||
}) | ||
|
||
test_that("causal_mod conversion", { | ||
m <- lm(yield ~ block + N*P*K, data=npk) | ||
x <- causal_mod(m) | ||
|
||
expect_type(as.double(x), "double") | ||
expect_type(c(x, double()), "double") | ||
expect_type(c(double(), x), "double") | ||
}) | ||
|
||
test_that("causal_mod slicing", { | ||
m <- lm(yield ~ block + N*P*K, data=npk) | ||
x <- causal_mod(m) | ||
|
||
expect_equal(as.double(x[4:2]), unname(fitted(m)[4:2])) | ||
expect_equal(attr(x[4:2], "idx"), 4:2) | ||
}) | ||
|
||
|
||
test_that("causal_mod printing", { | ||
m <- lm(yield ~ block + N*P*K, data=npk) | ||
x <- causal_mod(m) | ||
|
||
expect_snapshot(print(x)) | ||
expect_snapshot(str(x)) | ||
}) | ||
|