diff --git a/NAMESPACE b/NAMESPACE index 0abb390..07e6720 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -23,6 +23,10 @@ S3method(tbl_format_setup,causal_tbl) S3method(tbl_sum,causal_tbl) S3method(tidy,causal_mod) S3method(vcov,causal_mod) +S3method(vec_arith,causal_mod) +S3method(vec_arith.causal_mod,causal_mod) +S3method(vec_arith.causal_mod,default) +S3method(vec_arith.numeric,causal_mod) S3method(vec_cast,causal_idx.list) S3method(vec_cast,double.causal_mod) S3method(vec_cast,list.causal_idx) @@ -84,6 +88,8 @@ importFrom(stats,predict) importFrom(stats,residuals) importFrom(stats,simulate) importFrom(stats,vcov) +importFrom(vctrs,vec_arith) +importFrom(vctrs,vec_arith.numeric) importFrom(vctrs,vec_cast) importFrom(vctrs,vec_ptype2) importFrom(vctrs,vec_ptype_abbr) diff --git a/R/causal_mod.R b/R/causal_mod.R index 7987343..33a6be5 100644 --- a/R/causal_mod.R +++ b/R/causal_mod.R @@ -159,6 +159,30 @@ vec_ptype2.causal_mod.double <- function(x, y, ...) double() vec_cast.double.causal_mod <- function(x, to, ...) vctrs::vec_data(x) +#' @importFrom vctrs vec_arith +#' @method vec_arith causal_mod +#' @export +vec_arith.causal_mod <- function(op, x, y, ...) { + UseMethod("vec_arith.causal_mod", y) +} +#' @method vec_arith.causal_mod default +#' @export +vec_arith.causal_mod.default <- function(op, x, y, ...) { + vctrs::vec_arith_base(op, x, vctrs::vec_data(y), ...) +} +#' @method vec_arith.causal_mod causal_mod +#' @export +vec_arith.causal_mod.causal_mod <- function(op, x, y, ...) { + vctrs::vec_arith_base(op, vctrs::vec_data(x), vctrs::vec_data(y), ...) +} +#' @importFrom vctrs vec_arith.numeric +#' @method vec_arith.numeric causal_mod +#' @export +vec_arith.numeric.causal_mod <- function(op, x, y, ...) { + vctrs::vec_arith_base(op, vctrs::vec_data(x), y, ...) +} + + # model generics ------------------------------------------------------------------- #' @importFrom stats fitted diff --git a/tests/testthat/test-causal_mod.R b/tests/testthat/test-causal_mod.R index 292532a..25c74a1 100644 --- a/tests/testthat/test-causal_mod.R +++ b/tests/testthat/test-causal_mod.R @@ -5,6 +5,7 @@ test_that("causal_mod constructor", { expect_s3_class(x, "causal_mod") expect_type(x, "double") expect_true(is_causal_mod(x)) + expect_equal(length(x), nrow(npk)) expect_error(causal_mod(5), "atomic") expect_error(new_causal_mod(5), "atomic") @@ -45,7 +46,6 @@ test_that("causal_mod slicing", { expect_equal(attr(x[4:2], "idx"), 4:2) }) - test_that("causal_mod generics", { m <- lm(yield ~ block + N, data=npk) x <- causal_mod(m) @@ -62,6 +62,14 @@ test_that("causal_mod generics", { expect_identical(formula(x), formula(m)) }) +test_that("causal_mod arithmetic", { + m <- lm(yield ~ block + N, data=npk) + x <- causal_mod(m) + + expect_equal(x - x, rep(0, nrow(npk))) + expect_equal(x + 1, unname(fitted(m)) + 1) + expect_equal(1 + x, unname(fitted(m)) + 1) +}) test_that("causal_mod printing", { skip_on_cran()