Skip to content

Commit

Permalink
coverage + slice support + tests
Browse files Browse the repository at this point in the history
  • Loading branch information
CoryMcCartan committed Mar 25, 2023
1 parent e4eef64 commit 5949703
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ jobs:
upload-snapshots: true

- name: Test coverage
if: matrix.os == 'ubuntu-latest'
if: runner.os == 'Linux'
run: |
covr::codecov(
quiet = FALSE,
Expand Down
2 changes: 1 addition & 1 deletion R/causal_cols.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ set_outcome <- function(data, outcome) {
causal_cols(data)$outcome <- col
data[[col]] <- vctrs::vec_cast(data[[col]], numeric(), x_arg=col)
if (has_treatment(data)) {
names(causal_cols(data)$treatment[1]) <- col
names(causal_cols(data)$treatment)[1] <- col
}
data
}
Expand Down
50 changes: 26 additions & 24 deletions R/causal_tbl.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ validate_causal_tbl <- function(data, call = parent.frame()) {
if (!"outcome" %in% names(cols))
cli_abort("Missing `outcome` in causal_cols", call=call)
if (!"treatment" %in% names(cols))
cli_abort("Missing `outcome` in causal_cols", call=call)
cli_abort("Missing `treatment` in causal_cols", call=call)
if (!is.null(cols$outcome)) {
if (!is.character(cols$outcome))
cli_abort("The `outcome` causal_cols must be stored as a string.", call=call)
Expand All @@ -48,7 +48,7 @@ validate_causal_tbl <- function(data, call = parent.frame()) {
reconstruct.causal_tbl <- function(data, old) {
classes <- c("tbl_df", "tbl", "data.frame")
if (!is.data.frame(data)) {
cli_abort("{.arg {deparse(substittue(data))}} must be a data frame.",
cli_abort("{.arg {deparse(substitute(data))}} must be a data frame.",
call=parent.frame())
}

Expand All @@ -66,24 +66,6 @@ reconstruct.causal_tbl <- function(data, old) {
}
}

# copy causal_col from old object as needed/available
# if (!missing(old)) {
if (FALSE) {
if (!is.null(col <- get_outcome(old)) && col %in% names(data)) {
data <- set_outcome(data, col)
}
if (!is.null(col <- get_treatment(old)) && col %in% names(data)) {
data <- set_treatment(data, col)
}

other_csl_cols <- setdiff(names(causal_cols(old)), names(causal_cols(data)))
if (length(other_csl_cols) > 1) {
for (i in seq_len(length(other_csl_cols))) {
causal_cols(data)[[other_csl_cols[i]]] <- causal_cols(old)[[other_csl_cols[i]]]
}
}
}

class(data) <- c("causal_tbl", classes)
data
}
Expand Down Expand Up @@ -189,15 +171,31 @@ assert_df <- function(data, arg) {

#' @export
`[.causal_tbl` <- function(x, i) {
old_names <- names(x)
new_names <- names(x)[i]
out <- NextMethod()

cols <- causal_cols(x)
for (col in names(cols)) {
if (is.null(cols[[col]])) next
if (!cols[[col]] %in% old_names[i]) {
causal_cols(out)[[col]] = NULL
# figure out what to subset to
keep = cols[[col]] %in% new_names
new_col = cols[[col]][keep]
# handle subsetting
if (length(new_col) == 0) { # causal_col removed (set to NULL)
new_col = list(NULL)
names(new_col) = col
causal_cols(out)[col] = new_col
next
} else if (!is.null(names(new_col))) {
nms <- new_names[match(names(new_col), new_names)]
if (all(is.na(nms))) {
nms = NULL
} else {
nms[is.na(nms)] = ""
}
names(new_col) = nms
}
causal_cols(out)[[col]] = new_col
}

out
Expand All @@ -211,7 +209,11 @@ assert_df <- function(data, arg) {
cols <- causal_cols(x)
for (col in names(cols)) {
if (is.null(cols[[col]])) next
causal_cols(out)[[col]] = value[which(cols[[col]] == old_names)]
new_col = value[which(cols[[col]] == old_names)]
if (!is.null(names(cols[[col]]))) {
names(new_col) = value[which(names(cols[[col]]) == old_names)]
}
causal_cols(out)[[col]] = new_col
}

out
Expand Down
2 changes: 2 additions & 0 deletions R/dplyr.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ dplyr_reconstruct.causal_tbl <- function(data, template) {
reconstruct.causal_tbl(data, template)
}

# nocov start
# Register the `causal_tbl` S3 methods for the dplyr generics this package
# extends. Registration goes through vctrs::s3_register(), one generic at a
# time, in a fixed order.
register_s3_dplyr <- function() {
    dplyr_generics <- c("dplyr_reconstruct", "group_by", "ungroup", "rowwise")
    for (generic in dplyr_generics) {
        vctrs::s3_register(paste0("dplyr::", generic), "causal_tbl")
    }
}
# nocov end
2 changes: 2 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# nocov start
# Package load hook: its only action is to call register_s3_dplyr().
# All arguments passed to the hook are accepted via `...` and ignored.
.onLoad <- function(...) {
register_s3_dplyr()
}
# nocov end
18 changes: 18 additions & 0 deletions tests/testthat/_snaps/causal_tbl.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# causal_tbl printing

Code
print(x)
Output
# A <causal_tbl> [8 x 2]
[trt] [out]
milk_first guess
<dbl> <dbl>
1 0 0
2 1 1
3 0 0
4 1 1
5 1 1
6 0 0
7 0 0
8 1 1

4 changes: 4 additions & 0 deletions tests/testthat/test-causal_cols.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ test_that("getting and setting treatment works", {
expect_no_error(validate_causal_tbl(x_trt))
expect_equal(get_treatment(x_trt), "milk_first")

x_trt_out = set_outcome(x_trt, guess)
expect_equal(get_treatment(x_trt_out), c(guess="milk_first"))

expect_error(set_treatment(x, not_a_column), "doesn't exist")
expect_error(set_treatment(x, c(milk_first, guess)), "Only one")
expect_error(set_treatment(x, NULL), "Must select")
})

Expand Down
91 changes: 91 additions & 0 deletions tests/testthat/test-causal_tbl.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,44 @@ test_that("causal_tbl creation", {
expect_type(causal_cols(x), "list")
expect_null(get_outcome(x))
expect_null(get_treatment(x))

expect_s3_class(causal_tbl(), "causal_tbl") # test empty
x = as_causal_tbl(list(y=1))
expect_s3_class(x, c("causal_tbl", "tbl_df", "tbl", "data.frame"), exact=TRUE)
})

test_that("causal_tbl validation", {
    # a plain data frame must be rejected by the validator
    expect_error(validate_causal_tbl(data.frame()), "must have a")

    # both the `outcome` and `treatment` entries must exist in causal_cols
    x <- causal_tbl()
    causal_cols(x) <- list(treatment=NULL) # remove outcome
    expect_error(validate_causal_tbl(x), "Missing \\`outcome")
    causal_cols(x) <- list(outcome=NULL) # remove treatment
    expect_error(validate_causal_tbl(x), "Missing \\`treatment")

    # outcome: the referenced column must be numeric, and the reference
    # itself must be stored as a string
    x <- causal_tbl(y="5")
    causal_cols(x) <- list(outcome="y", treatment=NULL)
    expect_error(validate_causal_tbl(x), "must be numeric")
    causal_cols(x) <- list(outcome=5L, treatment=NULL)
    expect_error(validate_causal_tbl(x), "as a string")

    # treatment: same two checks. The reference must name the column that
    # actually exists (`t`), so that the numeric check runs against a real,
    # non-numeric column rather than a missing one.
    x <- causal_tbl(t="5")
    causal_cols(x) <- list(outcome=NULL, treatment="t")
    expect_error(validate_causal_tbl(x), "must be numeric")
    causal_cols(x) <- list(outcome=NULL, treatment=5L)
    expect_error(validate_causal_tbl(x), "as a string")

    # argument assertions: data frames pass assert_df(); non-data-frames and
    # non-causal_tbl objects are rejected
    x <- data.frame()
    expect_no_error(assert_df(x))
    expect_error(assert_df(5L), "data frame")
    expect_error(assert_causal_tbl(x), "causal_tbl")

    # reconstructing from a bare data.frame copy with the original as the
    # template must give back the original object
    x <- causal_tbl(milk_first = c(0, 1, 0, 1, 1, 0, 0, 1),
                    guess = c(0, 1, 0, 1, 1, 0, 0, 1),
                    .outcome = "guess")
    y <- data.frame(x)
    expect_identical(reconstruct.causal_tbl(y, x), x)
    expect_error(reconstruct.causal_tbl(5L, x), "data frame")
})

test_that("causal_tbl attributes", {
Expand Down Expand Up @@ -38,3 +76,56 @@ test_that("causal_tbl attributes", {
expect_equal(get_outcome(x), "guess")
expect_equal(get_treatment(x), c(guess="milk_first"))
})

test_that("causal_tbl slicing and renaming", {
    both_set <- causal_tbl(milk_first = c(0, 1, 0, 1, 1, 0, 0, 1),
                           guess = c(0, 1, 0, 1, 1, 0, 0, 1),
                           .outcome = guess,
                           .treatment = milk_first)

    # renaming the columns must be reflected in the stored causal_cols,
    # including the name attached to the treatment entry
    names(both_set) <- c("trt", "y")
    expect_equal(get_outcome(both_set), "y")
    expect_equal(get_treatment(both_set), c(y="trt"))

    # keeping only the outcome column clears the treatment
    outcome_only <- both_set[2]
    expect_equal(get_outcome(outcome_only), "y")
    expect_null(get_treatment(outcome_only))

    # keeping only the treatment column clears the outcome
    treatment_only <- both_set[1]
    expect_equal(get_treatment(treatment_only), "trt") # no y= !
    expect_null(get_outcome(treatment_only))

    # same rename + slice when only an outcome was ever set
    outcome_set <- causal_tbl(milk_first = c(0, 1, 0, 1, 1, 0, 0, 1),
                              guess = c(0, 1, 0, 1, 1, 0, 0, 1),
                              .outcome = guess)
    names(outcome_set) <- c("trt", "y")
    outcome_set <- outcome_set[2]
    expect_equal(get_outcome(outcome_set), "y")
    expect_null(get_treatment(outcome_set))
})

# Snapshot test for the print method of a causal_tbl with both an outcome and
# a treatment set. The expected output lives in
# tests/testthat/_snaps/causal_tbl.md.
test_that("causal_tbl printing", {
x <- causal_tbl(milk_first = c(0, 1, 0, 1, 1, 0, 0, 1),
guess = c(0, 1, 0, 1, 1, 0, 0, 1),
.outcome = guess,
.treatment = milk_first)

# NOTE: the snapshot records the call text `print(x)` as well as its output,
# so the variable name and the expression must stay exactly as written.
expect_snapshot(print(x))
})

test_that("causal_tbl + dplyr", {
    tbl <- causal_tbl(milk_first = c(0, 1, 0, 1, 1, 0, 0, 1),
                      guess = c(0, 1, 0, 1, 1, 0, 0, 1),
                      .outcome = guess)

    # grouping keeps both the class and the outcome column
    grouped <- dplyr::group_by(tbl, guess)
    expect_s3_class(grouped, "causal_tbl")
    expect_equal(get_outcome(grouped), "guess")

    # ungroup() and rowwise() keep the class as well
    expect_s3_class(dplyr::ungroup(tbl), "causal_tbl")
    expect_s3_class(dplyr::rowwise(tbl), "causal_tbl")

    # joining against a plain data frame keeps the class and the outcome
    lookup <- data.frame(milk_first=c(0, 0, 1, 1),
                         guess=c(0, 1, 0, 1),
                         correct=c(1, 0, 0, 1))
    joined <- dplyr::left_join(tbl, lookup, by=c("milk_first", "guess"))
    expect_s3_class(joined, "causal_tbl")
    expect_equal(get_outcome(joined), "guess")
})

0 comments on commit 5949703

Please sign in to comment.