Skip to content

Commit

Permalink
arithmetic on fitted values
Browse files Browse the repository at this point in the history
  • Loading branch information
CoryMcCartan committed Mar 27, 2023
1 parent 2b220fd commit 9a52c32
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 1 deletion.
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ S3method(tbl_format_setup,causal_tbl)
S3method(tbl_sum,causal_tbl)
S3method(tidy,causal_mod)
S3method(vcov,causal_mod)
S3method(vec_arith,causal_mod)
S3method(vec_arith.causal_mod,causal_mod)
S3method(vec_arith.causal_mod,default)
S3method(vec_arith.numeric,causal_mod)
S3method(vec_cast,causal_idx.list)
S3method(vec_cast,double.causal_mod)
S3method(vec_cast,list.causal_idx)
Expand Down Expand Up @@ -84,6 +88,8 @@ importFrom(stats,predict)
importFrom(stats,residuals)
importFrom(stats,simulate)
importFrom(stats,vcov)
importFrom(vctrs,vec_arith)
importFrom(vctrs,vec_arith.numeric)
importFrom(vctrs,vec_cast)
importFrom(vctrs,vec_ptype2)
importFrom(vctrs,vec_ptype_abbr)
24 changes: 24 additions & 0 deletions R/causal_mod.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,30 @@ vec_ptype2.causal_mod.double <- function(x, y, ...) double()
vec_cast.double.causal_mod <- function(x, to, ...) vctrs::vec_data(x)


#' @importFrom vctrs vec_arith
#' @method vec_arith causal_mod
#' @export
vec_arith.causal_mod <- function(op, x, y, ...) {
UseMethod("vec_arith.causal_mod", y)
}
#' @method vec_arith.causal_mod default
#' @export
vec_arith.causal_mod.default <- function(op, x, y, ...) {
vctrs::vec_arith_base(op, x, vctrs::vec_data(y), ...)
}
#' @method vec_arith.causal_mod causal_mod
#' @export
vec_arith.causal_mod.causal_mod <- function(op, x, y, ...) {
vctrs::vec_arith_base(op, vctrs::vec_data(x), vctrs::vec_data(y), ...)
}
#' @importFrom vctrs vec_arith.numeric
#' @method vec_arith.numeric causal_mod
#' @export
vec_arith.numeric.causal_mod <- function(op, x, y, ...) {
vctrs::vec_arith_base(op, vctrs::vec_data(x), y, ...)
}


# model generics -------------------------------------------------------------------

#' @importFrom stats fitted
Expand Down
10 changes: 9 additions & 1 deletion tests/testthat/test-causal_mod.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ test_that("causal_mod constructor", {
expect_s3_class(x, "causal_mod")
expect_type(x, "double")
expect_true(is_causal_mod(x))
expect_equal(length(x), nrow(npk))

expect_error(causal_mod(5), "atomic")
expect_error(new_causal_mod(5), "atomic")
Expand Down Expand Up @@ -45,7 +46,6 @@ test_that("causal_mod slicing", {
expect_equal(attr(x[4:2], "idx"), 4:2)
})


test_that("causal_mod generics", {
m <- lm(yield ~ block + N, data=npk)
x <- causal_mod(m)
Expand All @@ -62,6 +62,14 @@ test_that("causal_mod generics", {
expect_identical(formula(x), formula(m))
})

test_that("causal_mod arithmetic", {
m <- lm(yield ~ block + N, data=npk)
x <- causal_mod(m)

expect_equal(x - x, rep(0, nrow(npk)))
expect_equal(x + 1, unname(fitted(m)) + 1)
expect_equal(1 + x, unname(fitted(m)) + 1)
})

test_that("causal_mod printing", {
skip_on_cran()
Expand Down

0 comments on commit 9a52c32

Please sign in to comment.