Skip to content

Commit

Permalink
causal_mod
Browse files Browse the repository at this point in the history
  • Loading branch information
CoryMcCartan committed Mar 26, 2023
1 parent 846dddd commit 4621854
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 6 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ Imports:
tibble (>= 3.0.0),
tidyselect,
vctrs,
pillar
pillar,
stats
Suggests:
dplyr,
testthat (>= 3.0.0)
Expand Down
10 changes: 10 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
# Generated by roxygen2: do not edit by hand

S3method("[",causal_idx)
S3method("[",causal_mod)
S3method("[",causal_tbl)
S3method("names<-",causal_tbl)
S3method(ctl_new_pillar,causal_tbl)
S3method(format,causal_idx)
S3method(format,causal_mod)
S3method(str,causal_mod)
S3method(tbl_format_header,causal_tbl)
S3method(tbl_format_setup,causal_tbl)
S3method(tbl_sum,causal_tbl)
S3method(vec_cast,causal_idx.list)
S3method(vec_cast,double.causal_mod)
S3method(vec_cast,list.causal_idx)
S3method(vec_ptype2,causal_idx.list)
S3method(vec_ptype2,causal_mod.double)
S3method(vec_ptype2,double.causal_mod)
S3method(vec_ptype2,list.causal_idx)
S3method(vec_ptype_abbr,causal_idx)
S3method(vec_ptype_abbr,causal_mod)
export("causal_cols<-")
export(add_causal_col)
export(as_causal_idx)
export(as_causal_tbl)
export(causal_cols)
export(causal_idx)
export(causal_mod)
export(causal_tbl)
export(get_causal_col)
export(get_outcome)
Expand All @@ -28,8 +36,10 @@ export(has_outcome)
export(has_panel)
export(has_treatment)
export(is_causal_idx)
export(is_causal_mod)
export(is_causal_tbl)
export(new_causal_idx)
export(new_causal_mod)
export(new_causal_tbl)
export(set_causal_col)
export(set_outcome)
Expand Down
4 changes: 4 additions & 0 deletions R/causal_cols.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ set_outcome <- function(data, outcome) {
data <- as_causal_tbl(data)
col <- single_col_name(enquo(outcome), data, "outcome")
causal_cols(data)$outcomes <- col
# coerce
data[[col]] <- vctrs::vec_cast(data[[col]], numeric(), x_arg=col)
# handle names
if (has_treatment(data)) {
names(causal_cols(data)$treatments)[1] <- col
}
Expand Down Expand Up @@ -139,7 +141,9 @@ set_treatment <- function(data, treatment, outcome = get_outcome()) {
data <- as_causal_tbl(data)
col <- single_col_name(enquo(treatment), data, "treatment")
causal_cols(data)$treatments <- col
# coerce
data[[col]] <- vctrs::vec_cast(data[[col]], numeric(), x_arg=col)
# handle names
if (has_outcome(data)) {
names(causal_cols(data)$treatments) <- get_outcome(data)
}
Expand Down
6 changes: 3 additions & 3 deletions R/causal_idx.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ new_causal_idx <- function(x = list()) {
#' (because of subsetting) are set to NA.
#'
#' @param x
#' * For `causal_idx()` and `new_causal_idx()`: A list of indices of type `causal_idx`
#' * For `causal_idx()` and `new_causal_idx()`: A list of indices
#' * For `is_causal_idx()`: An object to test
#' * For `as_causal_idx()`: An object to coerce
#'
#' @returns A `causal_idx` object.
#' @returns A `causal_idx` object. For `is_causal_idx()`, a logical value.
#'
#' @examples
#' idx <- causal_idx(list(2, c(1, NA, 3), 2))
Expand Down Expand Up @@ -82,7 +82,7 @@ vec_ptype_abbr.causal_idx <- function(x, ...) {

#' @importFrom vctrs vec_ptype2
#' @export
vec_ptype2.list.causal_idx <- function(x, y, ...) list()
vec_ptype2.list.causal_idx <- function(x, y, ...) list() # nocov
#' @export
vec_ptype2.causal_idx.list <- function(x, y, ...) list()
#' @importFrom vctrs vec_cast
Expand Down
105 changes: 105 additions & 0 deletions R/causal_mod.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#' @describeIn causal_mod Construct a `causal_mod` with minimal checks
#' @export
new_causal_mod <- function(x = list(), fitted = double(0), idx = seq_along(fitted)) {
if (is.atomic(x)) {
cli_abort("{.arg x} cannot be an atomic type.")
}

names(fitted) <- NULL
vctrs::new_vctr(fitted[idx], model=x, idx=idx, class="causal_mod")
}

#' Construct a fitted model column
#'
#' The `causal_mod` class acts like a vector of fitted values, but it also
#' stores the fitted model for later predictions and summaries, and keeps track
#' of observations that have been dropped (e.g. due to missingness) or subsetted.
#'
#' @param x
#' * For `causal_mod()` and `new_causal_mod()`: A fitted model object.
#' For `causal_mod()` this should support the [fitted()] generic.
#' * For `is_causal_mod()`: An object to test
#' @param idx A set of indices that connect observations fed into the model with
#' fitted values. Should contain values from 1 to the number of fitted values,
#' and may contain `NA` values for observations with no corresponding fitted
#' value (such as those with missing data). Defaults to a sequence over the
#' fitted values, with `NA`s determined by [na.action()].
#' @param fitted A vector of fitted values. Extracted automatically from the
#' model object in `causal_mod()`.
#'
#' @returns A `causal_mod` object. For `is_causal_mod()`, a logical value.
#'
#' @examples
#' m <- lm(yield ~ block + N*P*K, data=npk)
#' causal_mod(m)
#'
#' d <- rbind(NA, npk)
#' m_mis <- lm(yield ~ block + N*P*K, data=d)
#' causal_mod(m_mis) # NA for missing value
#'
#' @order 1
#' @export
causal_mod <- function(x, idx = NULL) {
if (is.atomic(x)) {
cli_abort("{.arg x} cannot be an atomic type.")
}

fitted <- stats::fitted(x)
if (is.null(fitted)) {
cli_abort("{.arg x} does not have a {.fn fitted} method.")
}
if (is.null(idx)) {
nas <- na.action(x)
if (is.null(nas)) {
idx = seq_along(fitted)
} else {
idx = integer(length(fitted) + length(nas))
idx[nas] = NA_integer_
idx[-nas] = seq_along(fitted)
}
}

new_causal_mod(x, fitted, idx)
}

#' @describeIn causal_mod Return `TRUE` if an object is an `causal_mod` list
#' @export
is_causal_mod <- function(x) {
inherits(x, "causal_mod")
}


# printing
#' @export
format.causal_mod <- function(x, ...) {
formatC(vctrs::vec_data(x))
}
#' @export
str.causal_mod <- function(object, max.level=2, ...) {
NextMethod(max.level=max.level, ...) # nocov
}

# vctrs -------------------------------------------------------------------

#' @export
`[.causal_mod` <- function(x, i) {
out <- NextMethod()
attr(out, "idx") <- attr(out, "idx")[i]
out
}

#' @importFrom vctrs vec_ptype_abbr
#' @method vec_ptype_abbr causal_mod
#' @export
vec_ptype_abbr.causal_mod <- function(x, ...) {
"mod" # nocov
}

#' @importFrom vctrs vec_ptype2
#' @export
vec_ptype2.double.causal_mod <- function(x, y, ...) double() # nocov
#' @export
vec_ptype2.causal_mod.double <- function(x, y, ...) double()
#' @importFrom vctrs vec_cast
#' @export
vec_cast.double.causal_mod <- function(x, to, ...) vctrs::vec_data(x)
4 changes: 2 additions & 2 deletions man/causal_idx.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

54 changes: 54 additions & 0 deletions man/causal_mod.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions tests/testthat/_snaps/causal_mod.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# causal_mod printing

Code
print(x)
Output
<causal_mod[24]>
[1] 50.89 58.33 51.82 55.06 65.1 55.7 53.33 55.67 58.99 59.02 68.42 56.66
[13] 57.78 48.38 46.01 48.34 54.82 48.32 51.56 47.39 57.38 60.65 53.22 54.15

---

Code
str(x)
Output
mod [1:24] 50.89, 58.33, 51.82, 55.06, 65.1, 55.7, 53.33, 55.67, 58.99, 59...
@ model:List of 13
..- attr(*, "class")= chr "lm"
@ idx : int [1:24] 1 2 3 4 5 6 7 8 9 10 ...

44 changes: 44 additions & 0 deletions tests/testthat/test-causal_mod.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

test_that("causal_mod constructor", {
m <- lm(yield ~ block + N*P*K, data=npk)
x <- causal_mod(m)
expect_s3_class(x, "causal_mod")
expect_type(x, "double")
expect_true(is_causal_mod(x))

expect_error(causal_mod(5), "atomic")
expect_error(new_causal_mod(5), "atomic")
expect_error(causal_mod(list()), "fitted()")

d = rbind(NA, npk, NA, npk)
m <- lm(yield ~ block + N*P*K, data=d)
x <- causal_mod(m)
expect_equal(which(is.na(x)), unname(c(na.action(m))))
})

test_that("causal_mod conversion", {
m <- lm(yield ~ block + N*P*K, data=npk)
x <- causal_mod(m)

expect_type(as.double(x), "double")
expect_type(c(x, double()), "double")
expect_type(c(double(), x), "double")
})

test_that("causal_mod slicing", {
m <- lm(yield ~ block + N*P*K, data=npk)
x <- causal_mod(m)

expect_equal(as.double(x[4:2]), unname(fitted(m)[4:2]))
expect_equal(attr(x[4:2], "idx"), 4:2)
})


test_that("causal_mod printing", {
m <- lm(yield ~ block + N*P*K, data=npk)
x <- causal_mod(m)

expect_snapshot(print(x))
expect_snapshot(str(x))
})

0 comments on commit 4621854

Please sign in to comment.