Skip to content

Commit

Permalink
upgrade set_causal_cols(); add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
CoryMcCartan committed Mar 27, 2023
1 parent 3f37530 commit 35d5331
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 20 deletions.
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ export(pull_outcome)
export(pull_panel_time)
export(pull_panel_unit)
export(pull_treatment)
export(set_causal_col)
export(set_causal_cols)
export(set_outcome)
export(set_panel)
export(set_treatment)
Expand Down
20 changes: 12 additions & 8 deletions R/causal_cols.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#' @param data A [causal_tbl].
#' @param value New value for `causal_cols`.
#' @param ... Named attributes to add to `data`'s causal attributes.
#' @param what The causal column to get or set.
#' @param what The causal column to get or add.
#' @param ptype A type to coerce a single added column to.
#'
#' @returns Varies. Setter methods return the original `data`, perhaps invisibly.
Expand All @@ -16,9 +16,10 @@
#' milk_first = c(0, 1, 0, 1, 1, 0, 0, 1),
#' guess = c(0, 1, 0, 1, 1, 0, 0, 1)
#' ) |>
#' set_causal_col("treatment", guess=milk_first)
#' set_causal_cols(outcomes=guess, treatments=c(guess=milk_first))
#' print(data)
#' get_causal_col(data, "treatment")
#' get_causal_col(data, "treatments")
#' get_causal_col(data, "outcomes")
#' causal_cols(data)
#' @export
causal_cols <- function(data) {
Expand All @@ -33,11 +34,14 @@ causal_cols <- function(data) {

#' @describeIn causal_cols Set column(s) for a `causal_col`
#' @export
set_causal_col <- function(data, what, ...) {
set_causal_cols <- function(data, ...) {
data <- as_causal_tbl(data)
dots <- rlang::quo(c(...))
cols <- multi_col_name(dots, data, what)
causal_cols(data)[[what]] <- cols
dots <- rlang::enquos(...)
for (i in seq_along(dots)) {
what <- names(dots)[i]
cols <- multi_col_name(dots[[i]], data, what)
causal_cols(data)[[what]] <- cols
}
data
}
#' @describeIn causal_cols Add a single column to a `causal_col`
Expand All @@ -47,7 +51,7 @@ add_causal_col <- function(data, what, ..., ptype=NULL) {
dots <- enquos(...)
if (length(dots) > 1) {
cli_abort(c("Only one column can be added at a time.",
">"="Use {.fn set_causal_col} to add more than one column."),
">"="Use {.fn set_causal_cols} to add more than one column."),
call=parent.frame())
}

Expand Down
15 changes: 8 additions & 7 deletions man/causal_cols.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions tests/testthat/test-causal_cols.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ test_that("lower-level getting and setting", {
y = c(1, 3, 2, 3, 2, 4, 4, 5)
)

x2 <- set_causal_col(x, "treatments", y=trt, id=year)
x2 <- set_causal_cols(x, treatments=c(y=trt, id=year))
expect_equal(get_treatment(x2), c(y="trt"))
expect_length(causal_cols(x2)$treatments, 2)

Expand All @@ -100,10 +100,10 @@ test_that("lower-level getting and setting", {
expect_length(causal_cols(x3)$treatments, 1)
expect_error(add_causal_col(x, "treatments", y=trt, y2=trt2), "more than one")

x3 <- set_causal_col(x3, "treatments", id=year)
x3 <- set_causal_cols(x3, treatments=c(id=year))
expect_equal(get_treatment(x3), c(id="year"))
expect_length(causal_cols(x3)$treatments, 1)
expect_error(set_causal_col(x3, "treatments"), "Must select")
expect_error(set_causal_cols(x3, treatments=NULL), "Must select")

x4 <- add_causal_col(x, "pscore", trt=id, ptype=factor())
expect_type(causal_cols(x4)$pscore, "character")
Expand Down
8 changes: 7 additions & 1 deletion tests/testthat/test-causal_tbl.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,16 @@ test_that("causal_tbl slicing and renaming", {
x_df = as.data.frame(x)
expect_equal(x[1:2, ], causal_tbl(x_df[1:2, ], .outcome=y, .treatment=trt))

expect_identical(x, x[])

y <- x[1:4, ]
expect_equal(get_treatment(y), c(y="trt"))
expect_equal(get_outcome(y), "y")

y <- x[2]
expect_null(get_treatment(y))
expect_equal(get_outcome(y), "y")
y <- x[1]
y <- x[, 1]
expect_null(get_outcome(y))
expect_equal(get_treatment(y), "trt") # no y= !

Expand Down

0 comments on commit 35d5331

Please sign in to comment.