Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1667 from Rdatatable/groupingsets
RC - Grouping Sets, rollup, cube. #1377
- Loading branch information
Showing
5 changed files
with
411 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
rollup <- function(x, ...) { | ||
UseMethod("rollup") | ||
} | ||
rollup.data.table <- function(x, j, by, .SDcols, id = FALSE, ...) { | ||
# input data type basic validation | ||
if (!is.data.table(x)) | ||
stop("Argument 'x' must be data.table object") | ||
if (!is.character(by)) | ||
stop("Argument 'by' must be character vector of column names used in grouping.") | ||
if (!is.logical(id)) | ||
stop("Argument 'id' must be logical scalar.") | ||
# generate grouping sets for rollup | ||
sets = lapply(length(by):0, function(i) by[0:i]) | ||
# redirect to workhorse function | ||
jj = substitute(j) | ||
groupingsets.data.table(x, by=by, sets=sets, .SDcols=.SDcols, id=id, jj=jj) | ||
} | ||
|
||
cube <- function(x, ...) { | ||
UseMethod("cube") | ||
} | ||
cube.data.table <- function(x, j, by, .SDcols, id = FALSE, ...) { | ||
# input data type basic validation | ||
if (!is.data.table(x)) | ||
stop("Argument 'x' must be data.table object") | ||
if (!is.character(by)) | ||
stop("Argument 'by' must be character vector of column names used in grouping.") | ||
if (!is.logical(id)) | ||
stop("Argument 'id' must be logical scalar.") | ||
# generate grouping sets for cube - power set: http://stackoverflow.com/a/32187892/2490497 | ||
n = length(by) | ||
keepBool = sapply(2L^(1:n - 1L), function(k) rep(c(FALSE, TRUE), each=k, times=(2L^n / (2L*k)))) | ||
sets = lapply((2L^n):1, function(j) by[keepBool[j, ]]) | ||
# redirect to workhorse function | ||
jj = substitute(j) | ||
groupingsets.data.table(x, by=by, sets=sets, .SDcols=.SDcols, id=id, jj=jj) | ||
} | ||
|
||
groupingsets <- function(x, ...) { | ||
UseMethod("groupingsets") | ||
} | ||
groupingsets.data.table <- function(x, j, by, sets, .SDcols, id = FALSE, jj, ...) { | ||
# input data type basic validation | ||
if (!is.data.table(x)) | ||
stop("Argument 'x' must be data.table object") | ||
if (ncol(x) < 1L) | ||
stop("Argument 'x' is 0 column data.table, no measure to apply grouping over.") | ||
if (length(names(x)) != uniqueN(names(x))) | ||
stop("data.table must not contains duplicate column names.") | ||
if (!is.character(by)) | ||
stop("Argument 'by' must be character vector of column names used in grouping.") | ||
if (length(by) != uniqueN(by)) | ||
stop("Argument 'by' must have unique column names for grouping.") | ||
if (!is.list(sets) || !all(sapply(sets, is.character))) | ||
stop("Argument 'sets' must be a list of character vectors.") | ||
if (!is.logical(id)) | ||
stop("Argument 'id' must be logical scalar.") | ||
# logic constraints validation | ||
if (!all((sets.all.by <- unique(unlist(sets))) %chin% by)) | ||
stop(sprintf("All columns used in 'sets' argument must be in 'by' too. Columns used in 'sets' but not present in 'by': %s.", paste(setdiff(sets.all.by, by), collapse=", "))) | ||
if (id && "grouping" %chin% names(x)) | ||
stop("When using `id=TRUE` the 'x' data.table must not have column named 'grouping'.") | ||
if (!all(sapply(sets, function(x) length(x)==uniqueN(x)))) | ||
stop("Character vectors in 'sets' list must not have duplicated column names within single grouping set.") | ||
if (!identical(lapply(sets, sort), unique(lapply(sets, sort)))) | ||
warning("Double counting is going to happen. Argument 'sets' should be unique without taking order into account, unless you really want double counting, then get used to that warning. Otherwise `sets=unique(lapply(sets, sort))` will do the trick.") | ||
# input arguments handling | ||
jj = if (!missing(jj)) jj else substitute(j) | ||
av = all.vars(jj, TRUE) | ||
if (":=" %chin% av) | ||
stop("Expression passed to grouping sets function must not update by reference. Use ':=' on results of your grouping function.") | ||
if (missing(.SDcols)) | ||
.SDcols = if (".SD" %chin% av) setdiff(names(x), by) else NULL | ||
# 0 rows template data.table to keep colorder and type | ||
if (length(by)) { | ||
empty = if (length(.SDcols)) x[0L, eval(jj), by, .SDcols=.SDcols] else x[0L, eval(jj), by] | ||
} else { | ||
empty = if (length(.SDcols)) x[0L, eval(jj), .SDcols=.SDcols] else x[0L, eval(jj)] | ||
if (!is.data.table(empty)) empty = setDT(list(empty)) # improve after #648, see comment in aggr.set | ||
} | ||
if (id && "grouping" %chin% names(empty)) # `j` could have been evaluated to `grouping` field | ||
stop("When using `id=TRUE` the 'j' expression must not evaluate to column named 'grouping'.") | ||
if (length(names(empty)) != uniqueN(names(empty))) | ||
stop("There exists duplicated column names in the results, ensure the column passed/evaluated in `j` and those in `by` are not overlapping.") | ||
# adding grouping column to template - aggregation level identifier | ||
if (id) { | ||
set(empty, j = "grouping", value = integer()) | ||
setcolorder(empty, c("grouping", by, setdiff(names(empty), c("grouping", by)))) | ||
} | ||
# workaround for rbindlist fill=TRUE on integer64 #1459 | ||
int64.cols = vapply(empty, inherits, logical(1), "integer64") | ||
int64.cols = names(int64.cols)[int64.cols] | ||
if (length(int64.cols) && !requireNamespace("bit64", quietly=TRUE)) | ||
stop("Using integer64 class columns require to have 'bit64' package installed.") | ||
int64.by.cols = intersect(int64.cols, by) | ||
# aggregate function called for each grouping set | ||
aggregate.set <- function(by.set) { | ||
if (length(by.set)) { | ||
r = if (length(.SDcols)) x[, eval(jj), by.set, .SDcols=.SDcols] else x[, eval(jj), by.set] | ||
} else { | ||
r = if (length(.SDcols)) x[, eval(jj), .SDcols=.SDcols] else x[, eval(jj)] | ||
# workaround for grand total single var as data.table too, change to drop=FALSE after #648 solved | ||
if (!is.data.table(r)) r = setDT(list(r)) | ||
} | ||
if (id) { | ||
# integer bit mask of aggregation levels: http://www.postgresql.org/docs/9.5/static/functions-aggregate.html#FUNCTIONS-GROUPING-TABLE | ||
set(r, j = "grouping", value = strtoi(paste(c("1", "0")[by %chin% by.set + 1L], collapse=""), base=2L)) | ||
} | ||
if (length(int64.by.cols)) { | ||
# workaround for rbindlist fill=TRUE on integer64 #1459 | ||
missing.int64.by.cols = setdiff(int64.by.cols, by.set) | ||
if (length(missing.int64.by.cols)) r[, (missing.int64.by.cols) := bit64::as.integer64(NA)] | ||
} | ||
r | ||
} | ||
# actually processing everything here | ||
rbindlist(c( | ||
list(empty), # 0 rows template for colorder and type | ||
lapply(sets, aggregate.set) # all aggregations | ||
), use.names=TRUE, fill=TRUE) | ||
} |
Oops, something went wrong.