From 4621854fcf5de6bea93587b13ce24defcfaa7afa Mon Sep 17 00:00:00 2001
From: Cory McCartan
Date: Sat, 25 Mar 2023 23:47:48 -0400
Subject: [PATCH] causal_mod

---
 DESCRIPTION                         |   3 +-
 NAMESPACE                           |  10 +++
 R/causal_cols.R                     |   4 +
 R/causal_idx.R                      |   6 +-
 R/causal_mod.R                      | 110 ++++++++++++++++++++++++++++
 man/causal_idx.Rd                   |   4 +-
 man/causal_mod.Rd                   |  54 ++++++++++++++
 tests/testthat/_snaps/causal_mod.md |  19 +++++
 tests/testthat/test-causal_mod.R    |  50 +++++++++++++
 9 files changed, 254 insertions(+), 6 deletions(-)
 create mode 100644 R/causal_mod.R
 create mode 100644 man/causal_mod.Rd
 create mode 100644 tests/testthat/_snaps/causal_mod.md
 create mode 100644 tests/testthat/test-causal_mod.R

diff --git a/DESCRIPTION b/DESCRIPTION
index ab72f3e..721dc69 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -14,7 +14,8 @@ Imports:
     tibble (>= 3.0.0),
     tidyselect,
     vctrs,
-    pillar
+    pillar,
+    stats
 Suggests:
     dplyr,
     testthat (>= 3.0.0)
diff --git a/NAMESPACE b/NAMESPACE
index 5d60707..6030797 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -1,24 +1,32 @@
 # Generated by roxygen2: do not edit by hand
 
 S3method("[",causal_idx)
+S3method("[",causal_mod)
 S3method("[",causal_tbl)
 S3method("names<-",causal_tbl)
 S3method(ctl_new_pillar,causal_tbl)
 S3method(format,causal_idx)
+S3method(format,causal_mod)
+S3method(str,causal_mod)
 S3method(tbl_format_header,causal_tbl)
 S3method(tbl_format_setup,causal_tbl)
 S3method(tbl_sum,causal_tbl)
 S3method(vec_cast,causal_idx.list)
+S3method(vec_cast,double.causal_mod)
 S3method(vec_cast,list.causal_idx)
 S3method(vec_ptype2,causal_idx.list)
+S3method(vec_ptype2,causal_mod.double)
+S3method(vec_ptype2,double.causal_mod)
 S3method(vec_ptype2,list.causal_idx)
 S3method(vec_ptype_abbr,causal_idx)
+S3method(vec_ptype_abbr,causal_mod)
 export("causal_cols<-")
 export(add_causal_col)
 export(as_causal_idx)
 export(as_causal_tbl)
 export(causal_cols)
 export(causal_idx)
+export(causal_mod)
 export(causal_tbl)
 export(get_causal_col)
 export(get_outcome)
@@ -28,8 +36,10 @@ export(has_outcome)
 export(has_panel)
 export(has_treatment)
 export(is_causal_idx)
+export(is_causal_mod)
 export(is_causal_tbl)
 export(new_causal_idx)
+export(new_causal_mod)
 export(new_causal_tbl)
 export(set_causal_col)
 export(set_outcome)
diff --git a/R/causal_cols.R b/R/causal_cols.R
index ce35041..0e0787f 100644
--- a/R/causal_cols.R
+++ b/R/causal_cols.R
@@ -94,7 +94,9 @@ set_outcome <- function(data, outcome) {
     data <- as_causal_tbl(data)
     col <- single_col_name(enquo(outcome), data, "outcome")
     causal_cols(data)$outcomes <- col
+    # coerce
     data[[col]] <- vctrs::vec_cast(data[[col]], numeric(), x_arg=col)
+    # handle names
     if (has_treatment(data)) {
         names(causal_cols(data)$treatments)[1] <- col
     }
@@ -139,7 +141,9 @@ set_treatment <- function(data, treatment, outcome = get_outcome()) {
     data <- as_causal_tbl(data)
     col <- single_col_name(enquo(treatment), data, "treatment")
     causal_cols(data)$treatments <- col
+    # coerce
     data[[col]] <- vctrs::vec_cast(data[[col]], numeric(), x_arg=col)
+    # handle names
     if (has_outcome(data)) {
         names(causal_cols(data)$treatments) <- get_outcome(data)
     }
diff --git a/R/causal_idx.R b/R/causal_idx.R
index e951470..6eccb7c 100644
--- a/R/causal_idx.R
+++ b/R/causal_idx.R
@@ -19,11 +19,11 @@ new_causal_idx <- function(x = list()) {
 #' (because of subsetting) are set to NA.
 #'
 #' @param x
-#' * For `causal_idx()` and `new_causal_idx()`: A list of indices of type `causal_idx`
+#' * For `causal_idx()` and `new_causal_idx()`: A list of indices
 #' * For `is_causal_idx()`: An object to test
 #' * For `as_causal_idx()`: An object to coerce
 #'
-#' @returns A `causal_idx` object.
+#' @returns A `causal_idx` object. For `is_causal_idx()`, a logical value.
 #'
 #' @examples
 #' idx <- causal_idx(list(2, c(1, NA, 3), 2))
@@ -82,7 +82,7 @@ vec_ptype_abbr.causal_idx <- function(x, ...) {
 
 #' @importFrom vctrs vec_ptype2
 #' @export
-vec_ptype2.list.causal_idx <- function(x, y, ...) list()
+vec_ptype2.list.causal_idx <- function(x, y, ...) list() # nocov
 #' @export
 vec_ptype2.causal_idx.list <- function(x, y, ...) list()
 #' @importFrom vctrs vec_cast
diff --git a/R/causal_mod.R b/R/causal_mod.R
new file mode 100644
index 0000000..76d4bea
--- /dev/null
+++ b/R/causal_mod.R
@@ -0,0 +1,110 @@
+#' @describeIn causal_mod Construct a `causal_mod` with minimal checks
+#' @export
+new_causal_mod <- function(x = list(), fitted = double(0), idx = seq_along(fitted)) {
+    if (is.atomic(x)) {
+        cli_abort("{.arg x} cannot be an atomic type.")
+    }
+
+    names(fitted) <- NULL # drop any names carried by the fitted values
+    vctrs::new_vctr(fitted[idx], model=x, idx=idx, class="causal_mod")
+}
+
+#' Construct a fitted model column
+#'
+#' The `causal_mod` class acts like a vector of fitted values, but it also
+#' stores the fitted model for later predictions and summaries, and keeps track
+#' of observations that have been dropped (e.g. due to missingness) or subsetted.
+#'
+#' @param x
+#' * For `causal_mod()` and `new_causal_mod()`: A fitted model object.
+#' For `causal_mod()` this should support the [fitted()] generic.
+#' * For `is_causal_mod()`: An object to test
+#' @param idx A set of indices that connect observations fed into the model with
+#' fitted values. Should contain values from 1 to the number of fitted values,
+#' and may contain `NA` values for observations with no corresponding fitted
+#' value (such as those with missing data). Defaults to a sequence over the
+#' fitted values, with `NA`s determined by [na.action()].
+#' @param fitted A vector of fitted values. Extracted automatically from the
+#' model object in `causal_mod()`.
+#'
+#' @returns A `causal_mod` object. For `is_causal_mod()`, a logical value.
+#'
+#' @examples
+#' m <- lm(yield ~ block + N*P*K, data=npk)
+#' causal_mod(m)
+#'
+#' d <- rbind(NA, npk)
+#' m_mis <- lm(yield ~ block + N*P*K, data=d)
+#' causal_mod(m_mis) # NA for missing value
+#'
+#' @order 1
+#' @export
+causal_mod <- function(x, idx = NULL) {
+    if (is.atomic(x)) {
+        cli_abort("{.arg x} cannot be an atomic type.")
+    }
+
+    fitted <- stats::fitted(x)
+    if (is.null(fitted)) {
+        cli_abort("{.arg x} does not have a {.fn fitted} method.")
+    }
+    if (is.null(idx)) {
+        nas <- stats::na.action(x)
+        # Under `na.exclude`, `fitted()` re-inserts `NA` for dropped rows; strip
+        # that padding so `idx` alone records which observations were dropped
+        if (inherits(nas, "exclude")) {
+            fitted <- fitted[-nas]
+        }
+        if (is.null(nas)) {
+            idx <- seq_along(fitted)
+        } else {
+            idx <- integer(length(fitted) + length(nas))
+            idx[nas] <- NA_integer_
+            idx[-nas] <- seq_along(fitted)
+        }
+    }
+
+    new_causal_mod(x, fitted, idx)
+}
+
+#' @describeIn causal_mod Return `TRUE` if an object is a `causal_mod` object
+#' @export
+is_causal_mod <- function(x) {
+    inherits(x, "causal_mod")
+}
+
+
+# printing
+#' @export
+format.causal_mod <- function(x, ...) {
+    formatC(vctrs::vec_data(x))
+}
+#' @export
+str.causal_mod <- function(object, max.level=2, ...) {
+    NextMethod(max.level=max.level, ...) # nocov
+}
+
+# vctrs -------------------------------------------------------------------
+
+#' @export
+`[.causal_mod` <- function(x, i) {
+    out <- NextMethod()
+    attr(out, "idx") <- attr(out, "idx")[i] # subset `idx` by the same `i`
+    out
+}
+
+#' @importFrom vctrs vec_ptype_abbr
+#' @method vec_ptype_abbr causal_mod
+#' @export
+vec_ptype_abbr.causal_mod <- function(x, ...) {
+    "mod" # nocov
+}
+
+#' @importFrom vctrs vec_ptype2
+#' @export
+vec_ptype2.double.causal_mod <- function(x, y, ...) double() # nocov
+#' @export
+vec_ptype2.causal_mod.double <- function(x, y, ...) double()
+#' @importFrom vctrs vec_cast
+#' @export
+vec_cast.double.causal_mod <- function(x, to, ...) vctrs::vec_data(x)
diff --git a/man/causal_idx.Rd b/man/causal_idx.Rd
index 8cd059f..bc67a53 100644
--- a/man/causal_idx.Rd
+++ b/man/causal_idx.Rd
@@ -17,13 +17,13 @@ as_causal_idx(x)
 }
 \arguments{
 \item{x}{\itemize{
-\item For \code{causal_idx()} and \code{new_causal_idx()}: A list of indices of type \code{causal_idx}
+\item For \code{causal_idx()} and \code{new_causal_idx()}: A list of indices
 \item For \code{is_causal_idx()}: An object to test
 \item For \code{as_causal_idx()}: An object to coerce
 }}
 }
 \value{
-A \code{causal_idx} object.
+A \code{causal_idx} object. For \code{is_causal_idx()}, a logical value.
 }
 \description{
 A \code{causal_idx} list is a list with integer vector entries which represent
diff --git a/man/causal_mod.Rd b/man/causal_mod.Rd
new file mode 100644
index 0000000..93977f3
--- /dev/null
+++ b/man/causal_mod.Rd
@@ -0,0 +1,54 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/causal_mod.R
+\name{causal_mod}
+\alias{causal_mod}
+\alias{new_causal_mod}
+\alias{is_causal_mod}
+\title{Construct a fitted model column}
+\usage{
+causal_mod(x, idx = NULL)
+
+new_causal_mod(x = list(), fitted = double(0), idx = seq_along(fitted))
+
+is_causal_mod(x)
+}
+\arguments{
+\item{x}{\itemize{
+\item For \code{causal_mod()} and \code{new_causal_mod()}: A fitted model object.
+For \code{causal_mod()} this should support the \code{\link[=fitted]{fitted()}} generic.
+\item For \code{is_causal_mod()}: An object to test
+}}
+
+\item{idx}{A set of indices that connect observations fed into the model with
+fitted values. Should contain values from 1 to the number of fitted values,
+and may contain \code{NA} values for observations with no corresponding fitted
+value (such as those with missing data). Defaults to a sequence over the
+fitted values, with \code{NA}s determined by \code{\link[=na.action]{na.action()}}.}
+
+\item{fitted}{A vector of fitted values. Extracted automatically from the
+model object in \code{causal_mod()}.}
+}
+\value{
+A \code{causal_mod} object. For \code{is_causal_mod()}, a logical value.
+}
+\description{
+The \code{causal_mod} class acts like a vector of fitted values, but it also
+stores the fitted model for later predictions and summaries, and keeps track
+of observations that have been dropped (e.g. due to missingness) or subsetted.
+}
+\section{Functions}{
+\itemize{
+\item \code{new_causal_mod()}: Construct a \code{causal_mod} with minimal checks
+
+\item \code{is_causal_mod()}: Return \code{TRUE} if an object is a \code{causal_mod} object
+
+}}
+\examples{
+m <- lm(yield ~ block + N*P*K, data=npk)
+causal_mod(m)
+
+d <- rbind(NA, npk)
+m_mis <- lm(yield ~ block + N*P*K, data=d)
+causal_mod(m_mis) # NA for missing value
+
+}
diff --git a/tests/testthat/_snaps/causal_mod.md b/tests/testthat/_snaps/causal_mod.md
new file mode 100644
index 0000000..0bdd215
--- /dev/null
+++ b/tests/testthat/_snaps/causal_mod.md
@@ -0,0 +1,19 @@
+# causal_mod printing
+
+    Code
+      print(x)
+    Output
+
+       [1] 50.89 58.33 51.82 55.06 65.1  55.7  53.33 55.67 58.99 59.02 68.42 56.66
+      [13] 57.78 48.38 46.01 48.34 54.82 48.32 51.56 47.39 57.38 60.65 53.22 54.15
+
+---
+
+    Code
+      str(x)
+    Output
+       mod [1:24] 50.89, 58.33, 51.82, 55.06, 65.1, 55.7, 53.33, 55.67, 58.99, 59...
+       @ model:List of 13
+        ..- attr(*, "class")= chr "lm"
+       @ idx  : int [1:24] 1 2 3 4 5 6 7 8 9 10 ...
+
diff --git a/tests/testthat/test-causal_mod.R b/tests/testthat/test-causal_mod.R
new file mode 100644
index 0000000..e0f9a2d
--- /dev/null
+++ b/tests/testthat/test-causal_mod.R
@@ -0,0 +1,50 @@
+
+test_that("causal_mod constructor", {
+    m <- lm(yield ~ block + N*P*K, data=npk)
+    x <- causal_mod(m)
+    expect_s3_class(x, "causal_mod")
+    expect_type(x, "double")
+    expect_true(is_causal_mod(x))
+
+    expect_error(causal_mod(5), "atomic")
+    expect_error(new_causal_mod(5), "atomic")
+    expect_error(causal_mod(list()), "fitted()")
+
+    d <- rbind(NA, npk, NA, npk)
+    m <- lm(yield ~ block + N*P*K, data=d)
+    x <- causal_mod(m)
+    expect_equal(which(is.na(x)), unname(c(na.action(m))))
+
+    # `na.exclude` pads fitted values; output must still line up with the data
+    m <- lm(yield ~ block + N*P*K, data=d, na.action=na.exclude)
+    x <- causal_mod(m)
+    expect_length(x, nrow(d))
+    expect_equal(which(is.na(x)), unname(c(na.action(m))))
+})
+
+test_that("causal_mod conversion", {
+    m <- lm(yield ~ block + N*P*K, data=npk)
+    x <- causal_mod(m)
+
+    expect_type(as.double(x), "double")
+    expect_type(c(x, double()), "double")
+    expect_type(c(double(), x), "double")
+})
+
+test_that("causal_mod slicing", {
+    m <- lm(yield ~ block + N*P*K, data=npk)
+    x <- causal_mod(m)
+
+    expect_equal(as.double(x[4:2]), unname(fitted(m)[4:2]))
+    expect_equal(attr(x[4:2], "idx"), 4:2)
+})
+
+
+test_that("causal_mod printing", {
+    m <- lm(yield ~ block + N*P*K, data=npk)
+    x <- causal_mod(m)
+
+    expect_snapshot(print(x))
+    expect_snapshot(str(x))
+})
+