diff --git a/NAMESPACE b/NAMESPACE index 07e6720..5399c4f 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -63,7 +63,7 @@ export(pull_outcome) export(pull_panel_time) export(pull_panel_unit) export(pull_treatment) -export(set_causal_col) +export(set_causal_cols) export(set_outcome) export(set_panel) export(set_treatment) diff --git a/R/causal_cols.R b/R/causal_cols.R index cf26683..3503a1f 100644 --- a/R/causal_cols.R +++ b/R/causal_cols.R @@ -6,7 +6,7 @@ #' @param data A [causal_tbl]. #' @param value New value for `causal_cols`. #' @param ... Named attributes to add to `data`'s causal attributes. -#' @param what The causal column to get or set. +#' @param what The causal column to get or add. #' @param ptype A type to coerce a single added column to. #' #' @returns Varies. Setter methods return the original `data`, perhaps invisibly. @@ -16,9 +16,10 @@ #' milk_first = c(0, 1, 0, 1, 1, 0, 0, 1), #' guess = c(0, 1, 0, 1, 1, 0, 0, 1) #' ) |> -#' set_causal_col("treatment", guess=milk_first) +#' set_causal_cols(outcomes=guess, treatments=c(guess=milk_first)) #' print(data) -#' get_causal_col(data, "treatment") +#' get_causal_col(data, "treatments") +#' get_causal_col(data, "outcomes") #' causal_cols(data) #' @export causal_cols <- function(data) { @@ -33,11 +34,14 @@ causal_cols <- function(data) { #' @describeIn causal_cols Set column(s) for a `causal_col` #' @export -set_causal_col <- function(data, what, ...) { +set_causal_cols <- function(data, ...) { data <- as_causal_tbl(data) - dots <- rlang::quo(c(...)) - cols <- multi_col_name(dots, data, what) - causal_cols(data)[[what]] <- cols + dots <- rlang::enquos(...) + for (i in seq_along(dots)) { + what <- names(dots)[i] + cols <- multi_col_name(dots[[i]], data, what) + causal_cols(data)[[what]] <- cols + } data } #' @describeIn causal_cols Add a single column to a `causal_col` @@ -47,7 +51,7 @@ add_causal_col <- function(data, what, ..., ptype=NULL) { dots <- enquos(...) if (length(dots) > 1) { cli_abort(c("Only one column can be added at a time.", - ">"="Use {.fn set_causal_col} to add more than one column."), + ">"="Use {.fn set_causal_cols} to add more than one column."), call=parent.frame()) } diff --git a/man/causal_cols.Rd b/man/causal_cols.Rd index 802e3d2..af19c5d 100644 --- a/man/causal_cols.Rd +++ b/man/causal_cols.Rd @@ -3,7 +3,7 @@ \name{causal_cols} \alias{causal_cols} \alias{causal_cols<-} -\alias{set_causal_col} +\alias{set_causal_cols} \alias{add_causal_col} \alias{get_causal_col} \title{Work directly with \code{causal_cols}} @@ -12,7 +12,7 @@ causal_cols(data) causal_cols(data) <- value -set_causal_col(data, what, ...) +set_causal_cols(data, ...) add_causal_col(data, what, ..., ptype = NULL) @@ -23,10 +23,10 @@ get_causal_col(data, what) \item{value}{New value for \code{causal_cols}.} -\item{what}{The causal column to get or set.} - \item{...}{Named attributes to add to \code{data}'s causal attributes.} +\item{what}{The causal column to get or add.} + \item{ptype}{A type to coerce a single added column to.} } \value{ @@ -40,7 +40,7 @@ functionality. \itemize{ \item \code{causal_cols(data) <- value}: Set \code{causal_cols} -\item \code{set_causal_col()}: Set column(s) for a \code{causal_col} +\item \code{set_causal_cols()}: Set column(s) for a \code{causal_col} \item \code{add_causal_col()}: Add a single column to a \code{causal_col} @@ -52,8 +52,9 @@ data <- data.frame( milk_first = c(0, 1, 0, 1, 1, 0, 0, 1), guess = c(0, 1, 0, 1, 1, 0, 0, 1) ) |> - set_causal_col("treatment", guess=milk_first) + set_causal_cols(outcomes=guess, treatments=c(guess=milk_first)) print(data) -get_causal_col(data, "treatment") +get_causal_col(data, "treatments") +get_causal_col(data, "outcomes") causal_cols(data) } diff --git a/tests/testthat/test-causal_cols.R b/tests/testthat/test-causal_cols.R index c7f6fed..cb16080 100644 --- a/tests/testthat/test-causal_cols.R +++ b/tests/testthat/test-causal_cols.R @@ -91,7 +91,7 @@ test_that("lower-level getting and setting", { y = c(1, 3, 2, 3, 2, 4, 4, 5) ) - x2 <- set_causal_col(x, "treatments", y=trt, id=year) + x2 <- set_causal_cols(x, treatments=c(y=trt, id=year)) expect_equal(get_treatment(x2), c(y="trt")) expect_length(causal_cols(x2)$treatments, 2) @@ -100,10 +100,10 @@ test_that("lower-level getting and setting", { expect_length(causal_cols(x3)$treatments, 1) expect_error(add_causal_col(x, "treatments", y=trt, y2=trt2), "more than one") - x3 <- set_causal_col(x3, "treatments", id=year) + x3 <- set_causal_cols(x3, treatments=c(id=year)) expect_equal(get_treatment(x3), c(id="year")) expect_length(causal_cols(x3)$treatments, 1) - expect_error(set_causal_col(x3, "treatments"), "Must select") + expect_error(set_causal_cols(x3, treatments=NULL), "Must select") x4 <- add_causal_col(x, "pscore", trt=id, ptype=factor()) expect_type(causal_cols(x4)$pscore, "character") diff --git a/tests/testthat/test-causal_tbl.R b/tests/testthat/test-causal_tbl.R index 8bab351..3276019 100644 --- a/tests/testthat/test-causal_tbl.R +++ b/tests/testthat/test-causal_tbl.R @@ -84,10 +84,16 @@ test_that("causal_tbl slicing and renaming", { x_df = as.data.frame(x) expect_equal(x[1:2, ], causal_tbl(x_df[1:2, ], .outcome=y, .treatment=trt)) + expect_identical(x, x[]) + + y <- x[1:4, ] + expect_equal(get_treatment(y), c(y="trt")) + expect_equal(get_outcome(y), "y") + y <- x[2] expect_null(get_treatment(y)) expect_equal(get_outcome(y), "y") - y <- x[1] + y <- x[, 1] expect_null(get_outcome(y)) expect_equal(get_treatment(y), "trt") # no y= !