From 59497037d532e41d80c3c389217bc98b80f7286c Mon Sep 17 00:00:00 2001 From: Cory McCartan Date: Fri, 24 Mar 2023 23:40:22 -0400 Subject: [PATCH] coverage + slice support + tests --- .github/workflows/R-CMD-check.yaml | 2 +- R/causal_cols.R | 2 +- R/causal_tbl.R | 50 ++++++++-------- R/dplyr.R | 2 + R/zzz.R | 2 + tests/testthat/_snaps/causal_tbl.md | 18 ++++++ tests/testthat/test-causal_cols.R | 4 ++ tests/testthat/test-causal_tbl.R | 91 +++++++++++++++++++++++++++++ 8 files changed, 145 insertions(+), 26 deletions(-) create mode 100644 tests/testthat/_snaps/causal_tbl.md diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 9800f71..30c9b93 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -50,7 +50,7 @@ jobs: upload-snapshots: true - name: Test coverage - if: matrix.os == 'ubuntu-latest' + if: runner.os == 'Linux' run: | covr::codecov( quiet = FALSE, diff --git a/R/causal_cols.R b/R/causal_cols.R index 349bf46..9f61fd1 100644 --- a/R/causal_cols.R +++ b/R/causal_cols.R @@ -31,7 +31,7 @@ set_outcome <- function(data, outcome) { causal_cols(data)$outcome <- col data[[col]] <- vctrs::vec_cast(data[[col]], numeric(), x_arg=col) if (has_treatment(data)) { - names(causal_cols(data)$treatment[1]) <- col + names(causal_cols(data)$treatment)[1] <- col } data } diff --git a/R/causal_tbl.R b/R/causal_tbl.R index c4f0230..511c679 100644 --- a/R/causal_tbl.R +++ b/R/causal_tbl.R @@ -28,7 +28,7 @@ validate_causal_tbl <- function(data, call = parent.frame()) { if (!"outcome" %in% names(cols)) cli_abort("Missing `outcome` in causal_cols", call=call) if (!"treatment" %in% names(cols)) - cli_abort("Missing `outcome` in causal_cols", call=call) + cli_abort("Missing `treatment` in causal_cols", call=call) if (!is.null(cols$outcome)) { if (!is.character(cols$outcome)) cli_abort("The `outcome` causal_cols must be stored as a string.", call=call) @@ -48,7 +48,7 @@ validate_causal_tbl <- function(data, call = parent.frame()) { reconstruct.causal_tbl <- function(data, old) { classes <- c("tbl_df", "tbl", "data.frame") if (!is.data.frame(data)) { - cli_abort("{.arg {deparse(substittue(data))}} must be a data frame.", + cli_abort("{.arg {deparse(substitute(data))}} must be a data frame.", call=parent.frame()) } @@ -66,24 +66,6 @@ reconstruct.causal_tbl <- function(data, old) { } } - # copy causal_col from old object as needed/available - # if (!missing(old)) { - if (FALSE) { - if (!is.null(col <- get_outcome(old)) && col %in% names(data)) { - data <- set_outcome(data, col) - } - if (!is.null(col <- get_treatment(old)) && col %in% names(data)) { - data <- set_treatment(data, col) - } - - other_csl_cols <- setdiff(names(causal_cols(old)), names(causal_cols(data))) - if (length(other_csl_cols) > 1) { - for (i in seq_len(length(other_csl_cols))) { - causal_cols(data)[[other_csl_cols[i]]] <- causal_cols(old)[[other_csl_cols[i]]] - } - } - } - class(data) <- c("causal_tbl", classes) data } @@ -189,15 +171,31 @@ assert_df <- function(data, arg) { #' @export `[.causal_tbl` <- function(x, i) { - old_names <- names(x) + new_names <- names(x)[i] out <- NextMethod() cols <- causal_cols(x) for (col in names(cols)) { if (is.null(cols[[col]])) next - if (!cols[[col]] %in% old_names[i]) { - causal_cols(out)[[col]] = NULL + # figure out what to subset to + keep = cols[[col]] %in% new_names + new_col = cols[[col]][keep] + # handle subsetting + if (length(new_col) == 0) { # causal_col removed (set to NULL) + new_col = list(NULL) + names(new_col) = col + causal_cols(out)[col] = new_col + next + } else if (!is.null(names(new_col))) { + nms <- new_names[match(names(new_col), new_names)] + if (all(is.na(nms))) { + nms = NULL + } else { + nms[is.na(nms)] = "" + } + names(new_col) = nms } + causal_cols(out)[[col]] = new_col } out @@ -211,7 +209,11 @@ assert_df <- function(data, arg) { cols <- causal_cols(x) for (col in names(cols)) { if (is.null(cols[[col]])) next - causal_cols(out)[[col]] = value[which(cols[[col]] == old_names)] + new_col = value[which(cols[[col]] == old_names)] + if (!is.null(names(cols[[col]]))) { + names(new_col) = value[which(names(cols[[col]]) == old_names)] + } + causal_cols(out)[[col]] = new_col } out diff --git a/R/dplyr.R b/R/dplyr.R index af0dd30..0592ad8 100644 --- a/R/dplyr.R +++ b/R/dplyr.R @@ -12,9 +12,11 @@ dplyr_reconstruct.causal_tbl <- function(data, template) { reconstruct.causal_tbl(data, template) } +# nocov start register_s3_dplyr <- function() { vctrs::s3_register("dplyr::dplyr_reconstruct", "causal_tbl") vctrs::s3_register("dplyr::group_by", "causal_tbl") vctrs::s3_register("dplyr::ungroup", "causal_tbl") vctrs::s3_register("dplyr::rowwise", "causal_tbl") } +# nocov end diff --git a/R/zzz.R b/R/zzz.R index bd050b1..79390de 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -1,3 +1,5 @@ +# nocov start .onLoad <- function(...) { register_s3_dplyr() } +# nocov end diff --git a/tests/testthat/_snaps/causal_tbl.md b/tests/testthat/_snaps/causal_tbl.md new file mode 100644 index 0000000..bc37a9f --- /dev/null +++ b/tests/testthat/_snaps/causal_tbl.md @@ -0,0 +1,18 @@ +# causal_tbl printing + + Code + print(x) + Output + # A [8 x 2] + [trt] [out] + milk_first guess + + 1 0 0 + 2 1 1 + 3 0 0 + 4 1 1 + 5 1 1 + 6 0 0 + 7 0 0 + 8 1 1 + diff --git a/tests/testthat/test-causal_cols.R b/tests/testthat/test-causal_cols.R index c3a2df3..b70510c 100644 --- a/tests/testthat/test-causal_cols.R +++ b/tests/testthat/test-causal_cols.R @@ -15,7 +15,11 @@ test_that("getting and setting treatment works", { expect_no_error(validate_causal_tbl(x_trt)) expect_equal(get_treatment(x_trt), "milk_first") + x_trt_out = set_outcome(x_trt, guess) + expect_equal(get_treatment(x_trt_out), c(guess="milk_first")) + expect_error(set_treatment(x, not_a_column), "doesn't exist") + expect_error(set_treatment(x, c(milk_first, guess)), "Only one") expect_error(set_treatment(x, NULL), "Must select") }) diff --git a/tests/testthat/test-causal_tbl.R b/tests/testthat/test-causal_tbl.R index 86562ff..23c5476 100644 --- a/tests/testthat/test-causal_tbl.R +++ b/tests/testthat/test-causal_tbl.R @@ -8,6 +8,44 @@ test_that("causal_tbl creation", { expect_type(causal_cols(x), "list") expect_null(get_outcome(x)) expect_null(get_treatment(x)) + + expect_s3_class(causal_tbl(), "causal_tbl") # test empty + x = as_causal_tbl(list(y=1)) + expect_s3_class(x, c("causal_tbl", "tbl_df", "tbl", "data.frame"), exact=TRUE) +}) + +test_that("causal_tbl validation", { + expect_error(validate_causal_tbl(data.frame()), "must have a") + + x = causal_tbl() + causal_cols(x) <- list(treatment=NULL) # remove outcome + expect_error(validate_causal_tbl(x), "Missing \\`outcome") + causal_cols(x) <- list(outcome=NULL) # remove treatment + expect_error(validate_causal_tbl(x), "Missing \\`treatment") + + x = causal_tbl(y="5") + causal_cols(x) <- list(outcome="y", treatment=NULL) + expect_error(validate_causal_tbl(x), "must be numeric") + causal_cols(x) <- list(outcome=5L, treatment=NULL) + expect_error(validate_causal_tbl(x), "as a string") + + x = causal_tbl(t="5") + causal_cols(x) <- list(outcome=NULL, treatment="y") + expect_error(validate_causal_tbl(x), "must be numeric") + causal_cols(x) <- list(outcome=NULL, treatment=5L) + expect_error(validate_causal_tbl(x), "as a string") + + x = data.frame() + expect_no_error(assert_df(x)) + expect_error(assert_df(5L), "data frame") + expect_error(assert_causal_tbl(x), "causal_tbl") + + x <- causal_tbl(milk_first = c(0, 1, 0, 1, 1, 0, 0, 1), + guess = c(0, 1, 0, 1, 1, 0, 0, 1), + .outcome = "guess") + y <- data.frame(x) + expect_identical(reconstruct.causal_tbl(y, x), x) + expect_error(reconstruct.causal_tbl(5L, x), "data frame") }) test_that("causal_tbl attributes", { @@ -38,3 +76,56 @@ test_that("causal_tbl attributes", { expect_equal(get_outcome(x), "guess") expect_equal(get_treatment(x), c(guess="milk_first")) }) + +test_that("causal_tbl slicing and renaming", { + x <- causal_tbl(milk_first = c(0, 1, 0, 1, 1, 0, 0, 1), + guess = c(0, 1, 0, 1, 1, 0, 0, 1), + .outcome = guess, + .treatment = milk_first) + + names(x) <- c("trt", "y") + expect_equal(get_treatment(x), c(y="trt")) + expect_equal(get_outcome(x), "y") + + y <- x[2] + expect_null(get_treatment(y)) + expect_equal(get_outcome(y), "y") + y <- x[1] + expect_null(get_outcome(y)) + expect_equal(get_treatment(y), "trt") # no y= ! + + x <- causal_tbl(milk_first = c(0, 1, 0, 1, 1, 0, 0, 1), + guess = c(0, 1, 0, 1, 1, 0, 0, 1), + .outcome = guess) + names(x) <- c("trt", "y") + x <- x[2] + expect_null(get_treatment(x)) + expect_equal(get_outcome(x), "y") +}) + +test_that("causal_tbl printing", { + x <- causal_tbl(milk_first = c(0, 1, 0, 1, 1, 0, 0, 1), + guess = c(0, 1, 0, 1, 1, 0, 0, 1), + .outcome = guess, + .treatment = milk_first) + + expect_snapshot(print(x)) +}) + +test_that("causal_tbl + dplyr", { + x <- causal_tbl(milk_first = c(0, 1, 0, 1, 1, 0, 0, 1), + guess = c(0, 1, 0, 1, 1, 0, 0, 1), + .outcome = guess) + + expect_s3_class(dplyr::group_by(x, guess), "causal_tbl") + expect_equal(get_outcome(dplyr::group_by(x, guess)), "guess") + expect_s3_class(dplyr::ungroup(x), "causal_tbl") + expect_s3_class(dplyr::rowwise(x), "causal_tbl") + + y <- data.frame(milk_first=c(0, 0, 1, 1), + guess=c(0, 1, 0, 1), + correct=c(1, 0, 0, 1)) + out <- dplyr::left_join(x, y, by=c("milk_first", "guess")) + expect_s3_class(out, "causal_tbl") + expect_equal(get_outcome(out), "guess") +})