Skip to content

Commit

Permalink
generic causal_col fns
Browse files Browse the repository at this point in the history
  • Loading branch information
CoryMcCartan committed Mar 25, 2023
1 parent 2022f5a commit 3ba4dd5
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 86 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Description: Provides a 'causal_tbl' class for causal inference. A 'causal_tbl'
outcome, and provides functionality to store models and their fitted
values as columns in a data frame.
Imports:
rlang,
cli,
tibble (>= 3.0.0),
tidyselect,
Expand Down
5 changes: 4 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ S3method(ctl_new_pillar,causal_tbl)
S3method(tbl_format_header,causal_tbl)
S3method(tbl_format_setup,causal_tbl)
S3method(tbl_sum,causal_tbl)
export("causal_cols<-")
export(add_causal_col)
export(as_causal_tbl)
export(causal_cols)
export(causal_tbl)
export(get_causal_col)
export(get_outcome)
Expand All @@ -28,4 +30,5 @@ importFrom(pillar,ctl_new_pillar)
importFrom(pillar,tbl_format_header)
importFrom(pillar,tbl_format_setup)
importFrom(pillar,tbl_sum)
importFrom(tidyselect,enquo)
importFrom(rlang,enquo)
importFrom(rlang,enquos)
115 changes: 78 additions & 37 deletions R/causal_cols.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,78 @@
# internal accessors
#' Work directly with `causal_cols`
#'
#' These functions are aimed at developers who wish to extend `causal_tbl`
#' functionality.
#'
#' @param data A [causal_tbl].
#' @param value New value for `causal_cols`.
#' @param ... Named attributes to add to `data`'s causal attributes.
#' @param what The causal column to get or set.
#' @param ptype A type to coerce a single added column to.
#'
#' @returns Varies. Setter methods return the original `data`, perhaps invisibly.
#'
#' @examples
#' data <- data.frame(
#' milk_first = c(0, 1, 0, 1, 1, 0, 0, 1),
#' guess = c(0, 1, 0, 1, 1, 0, 0, 1)
#' ) |>
#' set_causal_col("treatment", guess=milk_first)
#' print(data)
#' get_causal_col(data, "treatment")
#' causal_cols(data)
#' @export
causal_cols <- function(data) {
attr(data, "causal_cols")
}
#' @describeIn causal_cols Set `causal_cols`
#' @export
`causal_cols<-` = function(data, value) {
attr(data, "causal_cols") <- value
data
}

#' @describeIn causal_cols Set column(s) for a `causal_col`
#' @export
set_causal_col <- function(data, what, ...) {
data <- as_causal_tbl(data)
dots <- rlang::quo(c(...))
cols <- multi_col_name(dots, data, what)
causal_cols(data)[[what]] <- cols
data
}
#' @describeIn causal_cols Add a single column to a `causal_col`
#' @export
add_causal_col <- function(data, what, ..., ptype=NULL) {
data <- as_causal_tbl(data)
dots <- enquos(...)
if (length(dots) > 1) {
cli_abort("Use {.fn set_causal_col} to add more than one column at a time")
}

col <- single_col_name(dots[[1]], data, what)
names(col) = names(dots)

if (what %in% names(causal_cols(data))) {
causal_cols(data)[[what]] <- c(causal_cols(data)[[what]], col)
} else {
causal_cols(data)[[what]] <- col
}

if (!is.null(ptype)) {
data[[col]] <- vctrs::vec_cast(data[[col]], ptype, x_arg=col)
}
data
}

#' @describeIn causal_cols Get the column name of the requested variable
#' @export
get_causal_col <- function(data, what) {
causal_cols(data)[[what]]
}





#' Define an outcome variable for a `causal_tbl`
#'
Expand Down Expand Up @@ -145,42 +211,6 @@ has_panel <- function(data) {
}


#' Define Custom Causal Attributes
#'
#' @param data a data frame or `causal_tbl`
#' @param ... named attributes to add to `data`'s causal attributes
#'
#' @return A `causal_tbl`
#' @export
#'
#' @examples
#' data <- data.frame(
#' milk_first = c(0, 1, 0, 1, 1, 0, 0, 1),
#' guess = c(0, 1, 0, 1, 1, 0, 0, 1)
#' ) |>
#' set_causal_col(milk_first, "treatment")
#' print(data) # a causal_tbl
#' get_causal_col(data, "treatment")
set_causal_col <- function(data, ...) {
data <- as_causal_tbl(data)

data
}

#' @rdname set_causal_col
#' @return For `get_causal_col()` the column name of the requested variable
#' @param what the causal attribute to get
#' @export
get_causal_col <- function(data, what) {

}

#' @rdname set_causal_col
#' @return For `add_causal_col()` A `causal_tbl`
#' @export
add_causal_col <- function(data, ...) {

}


# Helper
Expand All @@ -193,3 +223,14 @@ single_col_name <- function(expr, data, arg) {
}
names(data)[idx]
}

# Helper
multi_col_name <- function(expr, data, arg) {
idx <- tidyselect::eval_select(expr, data, allow_rename=TRUE)
if (length(idx) == 0) {
cli_abort("Must select a column for {.arg {arg}}", call=parent.frame())
}
out <- names(data)[idx]
names(out) = names(idx)
out
}
15 changes: 8 additions & 7 deletions R/causal_tbl.R
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ tbl_format_header.causal_tbl <- function(x, setup, ...) {
#' @method tbl_format_setup causal_tbl
#' @export
tbl_format_setup.causal_tbl <- function(x, width, ..., n, max_extra_cols, max_footer_lines, focus) {
NextMethod(focus=unlist(causal_cols(x)))
NextMethod(focus=unique(unlist(causal_cols(x))))
}


Expand All @@ -258,14 +258,15 @@ ctl_new_pillar.causal_tbl <- function(controller, x, width, ..., title = NULL) {
out <- NextMethod()
cols <- causal_cols(controller)
matched_types = vapply(cols, function(y) match(title, y)[1], 0L)
marker_type = names(which(!is.na(matched_types)))
marker = if (length(marker_type) == 0 || is.na(marker_type)) {
marker_type = names(which(!is.na(matched_types)))[1] # first match only
marker = c(
outcome="[out]", treatment="[trt]",
panel_unit="[unit]", panel_time="[time]"
)[marker_type]
marker = if (length(marker) == 0 || is.na(marker)) {
""
} else {
pillar::style_subtle(c(
outcome="[out]", treatment="[trt]",
panel_unit="[unit]", panel_time="[time]"
)[marker_type])
pillar::style_subtle(marker)
}

pillar::new_pillar(list(
Expand Down
2 changes: 1 addition & 1 deletion R/causaltbl-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@

## usethis namespace: start
#' @importFrom cli cli_abort cli_warn cli_inform
#' @importFrom tidyselect enquo
#' @importFrom rlang enquo enquos
## usethis namespace: end
NULL
59 changes: 59 additions & 0 deletions man/causal_cols.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

40 changes: 0 additions & 40 deletions man/set_causal_col.Rd

This file was deleted.

28 changes: 28 additions & 0 deletions tests/testthat/test-causal_cols.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,31 @@ test_that("getting and setting panel data", {
expect_error(set_panel(x, unit=id, time=not_a_column), "doesn't exist")
})

test_that("lower-level getting and setting", {
x <- causal_tbl(
id = c("a", "a", "a", "a", "b", "b", "b", "b"),
year = rep(2015:2018, 2),
trt = c(0, 0, 0, 0, 0, 0, 1, 1),
y = c(1, 3, 2, 3, 2, 4, 4, 5)
)

x2 <- set_causal_col(x, "treatment", y=trt, id=year)
expect_equal(get_treatment(x2), c(y="trt"))
expect_length(causal_cols(x2)$treatment, 2)

x3 <- add_causal_col(x, "treatment", y=trt)
expect_equal(get_treatment(x3), c(y="trt"))
expect_length(causal_cols(x3)$treatment, 1)
expect_error(add_causal_col(x, "treatment", y=trt, y2=trt2), "more than one")

x3 <- set_causal_col(x3, "treatment", id=year)
expect_equal(get_treatment(x3), c(id="year"))
expect_length(causal_cols(x3)$treatment, 1)
expect_error(set_causal_col(x3, "treatment"), "Must select")

x4 <- add_causal_col(x, "pscore", trt=id, ptype=factor())
expect_type(causal_cols(x4)$pscore, "character")
expect_s3_class(x4[[causal_cols(x4)$pscore]], "factor")
expect_equal(get_causal_col(x4, "pscore"), c(trt="id"))
})

0 comments on commit 3ba4dd5

Please sign in to comment.