Skip to content

Commit

Permalink
indices type
Browse files Browse the repository at this point in the history
  • Loading branch information
CoryMcCartan committed Mar 25, 2023
1 parent a664a50 commit 14c3884
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 7 deletions.
15 changes: 15 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
# Generated by roxygen2: do not edit by hand

S3method("[",causal_tbl)
S3method("[",indices)
S3method("names<-",causal_tbl)
S3method(ctl_new_pillar,causal_tbl)
S3method(format,indices)
S3method(tbl_format_header,causal_tbl)
S3method(tbl_format_setup,causal_tbl)
S3method(tbl_sum,causal_tbl)
S3method(vec_cast,indices.list)
S3method(vec_cast,list.indices)
S3method(vec_ptype2,indices.list)
S3method(vec_ptype2,list.indices)
S3method(vec_ptype_abbr,indices)
export("causal_cols<-")
export(add_causal_col)
export(as_causal_tbl)
export(as_indices)
export(causal_cols)
export(causal_tbl)
export(get_causal_col)
Expand All @@ -18,7 +26,11 @@ export(get_treatment)
export(has_outcome)
export(has_panel)
export(has_treatment)
export(indices)
export(is_causal_tbl)
export(is_indices)
export(new_causal_tbl)
export(new_indices)
export(set_causal_col)
export(set_outcome)
export(set_panel)
Expand All @@ -32,3 +44,6 @@ importFrom(pillar,tbl_format_setup)
importFrom(pillar,tbl_sum)
importFrom(rlang,enquo)
importFrom(rlang,enquos)
importFrom(vctrs,vec_cast)
importFrom(vctrs,vec_ptype2)
importFrom(vctrs,vec_ptype_abbr)
6 changes: 5 additions & 1 deletion R/causal_cols.R
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ multi_col_name <- function(expr, data, arg) {
cli_abort("Must select a column for {.arg {arg}}", call=parent.frame())
}
out <- names(data)[idx]
names(out) = names(idx)
nms <- names(idx)
if (!all(nms == out)) {
nms[nms == out] = ""
names(out) <- nms
}
out
}
16 changes: 12 additions & 4 deletions R/causal_tbl.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,12 @@ reconstruct.causal_tbl <- function(data, old) {
#' @return A `causal_tbl` object
#'
#' @examples
#' causal_tbl(
#' data <- causal_tbl(
#' milk_first = c(0, 1, 0, 1, 1, 0, 0, 1),
#' guess = c(0, 1, 0, 1, 1, 0, 0, 1)
#' )
#' is_causal_tbl(data)
#' print(data)
#'
#' @export
causal_tbl <- function(..., .outcome=NULL, .treatment=NULL) {
Expand All @@ -142,10 +144,10 @@ causal_tbl <- function(..., .outcome=NULL, .treatment=NULL) {


#' @describeIn causal_tbl Coerce a data frame to a `causal_tbl`
#' @param x A data frame to be coerced
#' @param x A data frame to be checked or coerced
#' @export
as_causal_tbl <- function(x) {
if (inherits(x, "causal_tbl")) {
if (is_causal_tbl(x)) {
x
} else if (is.data.frame(x)) {
reconstruct.causal_tbl(x)
Expand All @@ -154,8 +156,14 @@ as_causal_tbl <- function(x) {
}
}

#' @describeIn causal_tbl Return `TRUE` if a data frame is a `causal_tbl`
#' @export
is_causal_tbl <- function(x) {
inherits(x, "causal_tbl")
}

assert_causal_tbl <- function(data, arg) {
if (!inherits(data, "causal_tbl")) {
if (!is_causal_tbl(data)) {
cli_abort("{.arg {deparse(substitute(data))}} must be a {.cls causal_tbl}.",
call=parent.frame())
}
Expand Down
90 changes: 90 additions & 0 deletions R/indices.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#' @describeIn indices Construct an `indices` list with minimal checks
#' @export
new_indices <- function(x = list()) {
if (!is.list(x)) {
cli_abort("{.arg x} must be a list.")
}

vctrs::new_vctr(x, class="indices")
}

#' Construct a list of indices
#'
#' An `indices` list is a list with integer vector entries which represent
#' indices in some other object (like a data frame). Generically it can
#' be used to represent a graph structure, such as an interference network
#' or a collection of matched objects.
#' The main feature of `indices` is that the index references are preserved
#' through slicing and reordering. Indices that no longer refer to elements
#' (because of subsetting) are set to NA.
#'
#' @param x
#' * For `indices()` and `new_indices()`: A list of indices
#' * For `is_indices()`: An object to test
#' * For `as_indices()`: An object to coerce
#'
#' @returns An `indices` object.
#'
#' @examples
#' idx <- indices(list(2, c(1, NA, 3), 2))
#' print(idx)
#' idx[1:2] # subsetting
#' idx[c(2, 1, 3)] # reordering
#' @export
indices <- function(x = list()) {
# convert each element to an integer
x <- relist(vctrs::vec_cast(unlist(x), integer(), x_arg="x"), x)
new_indices(x)
}

#' @describeIn indices Return `TRUE` if an object is an `indices` list
#' @export
is_indices <- function(x) {
inherits(x, "indices")
}

#' @describeIn indices Coerce an object to an `indices` list
#' @export
as_indices <- function(x) {
vctrs::vec_cast(x, new_indices())
}


# printing
#' @export
format.indices <- function(x, ...) {
vapply(vctrs::vec_data(x), format_index_line, "")
}
format_index_line <- function(y) {
paste0("(", paste0(formatC(y[!is.na(y)]), collapse=","), ")")
}

# vctrs -------------------------------------------------------------------

#' @export
`[.indices` <- function(x, i) {
lookup <- match(seq_along(x), i)
out <- NextMethod()
for (j in seq_along(out)) {
out[[j]] = lookup[out[[j]]]
}
out
}

#' @importFrom vctrs vec_ptype_abbr
#' @method vec_ptype_abbr indices
#' @export
vec_ptype_abbr.indices <- function(x, ...) {
"idx" # nocov
}

#' @importFrom vctrs vec_ptype2
#' @export
vec_ptype2.list.indices <- function(x, y, ...) list()
#' @export
vec_ptype2.indices.list <- function(x, y, ...) list()
#' @importFrom vctrs vec_cast
#' @export
vec_cast.list.indices <- function(x, y, ...) as.list(x)
#' @export
vec_cast.indices.list <- function(x, y, ...) indices(x)
11 changes: 9 additions & 2 deletions man/causal_tbl.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions man/indices.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions tests/testthat/_snaps/indices.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# indices printing

Code
print(idx)
Output
<indices[3]>
[1] (2) (1,3) (2)

---

Code
str(idx)
Output
idx [1:3]
$ : int 2
$ : int [1:3] 1 NA 3
$ : int 2

36 changes: 36 additions & 0 deletions tests/testthat/test-indices.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
test_that("indices constructor", {
idx <- indices(list(2, c(1, NA, 3), 2))
expect_s3_class(idx, "indices")
expect_type(idx, "list")
expect_true(is_indices(idx))

expect_error(indices(list("a")), "character")
expect_error(indices(5), "must be a list")
})

test_that("indices conversion", {
idx <- indices(list(2, c(1, NA, 3), 2))

expect_s3_class(as_indices(as.list(idx)), "indices")
expect_type(as.list(idx), "list")
expect_type(c(idx, list()), "list")
expect_type(c(list(), idx), "list")
})

test_that("indices slicing", {
idx <- indices(list(2, c(1, NA, 3), 2))

expect_equal(idx[1:3], idx)
expect_equal(idx[1:2],
indices(list(2, c(1, NA, NA))))
expect_equal(idx[2:1],
indices(list(c(2, NA, NA), 1)))
})


test_that("indices printing", {
idx <- indices(list(2, c(1, NA, 3), 2))

expect_snapshot(print(idx))
expect_snapshot(str(idx))
})

0 comments on commit 14c3884

Please sign in to comment.