Skip to content

Commit

Permalink
Multiple coefficients from single declare_estimator step.
Browse files Browse the repository at this point in the history
  • Loading branch information
nfultz committed Feb 1, 2018
1 parent 81e46b5 commit 5a16390
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 51 deletions.
39 changes: 31 additions & 8 deletions R/aaa.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,9 @@ currydata <- function(FUN, dots, addDataArg=TRUE,strictDataParam=TRUE) {
}
}

default_declaration_validation_callback <- function(decl, dots) decl

#' @importFrom rlang enquo
declaration_template <- function(..., handler, label=NULL){
#message("Declared")
d <- enquo(handler);

dots <- quos(...,label=!!label)
this <- attributes(sys.function())
Expand All @@ -61,17 +58,30 @@ declaration_template <- function(..., handler, label=NULL){



ret <- structure(currydata(handler, dots, strictDataParam=this$strictDataParam),
ret <- build_step(currydata(handler, dots, strictDataParam=this$strictDataParam),
handler=handler,
dots=dots,
label=label,
step_type=this$step_type,
causal_type=this$causal_type,
call=match.call() )
call=match.call())

if(is.function(this$validation)) ret <- this$validation(ret, handler, dots, label)
if(has_validation_fn(handler)) ret <- validate(handler, ret, dots, label)

ret
}

build_step <- function(curried_fn, handler, dots, label, step_type, causal_type, call){
structure(curried_fn,
handler=handler,
dots=dots,
label=label,
step_type=step_type,
causal_type=causal_type,
call=call,
class=c("design_step", "function"))
}

make_declarations <- function(default_handler, step_type, causal_type='dgp', default_label, validation=NULL, strictDataParam=TRUE) {

declaration <- declaration_template
Expand All @@ -81,14 +91,27 @@ make_declarations <- function(default_handler, step_type, causal_type='dgp', def
if(!missing(default_label)) formals(declaration)$label <- default_label

structure(declaration,
class=c('Declared', 'function'),
class=c('declaration', 'function'),
step_type=step_type,
causal_type=causal_type,
validation=validation,
strictDataParam=strictDataParam)
}

########################################################################

validation_fn <- function(f){
attr(f, "validation_fn")
}

`validation_fn<-` <- with_validation_fn <- function(f, v) {
attr(f, "validation_fn") <- v
v
}

has_validation_fn <- function(f){
is.function(validation_fn(f))
}

validate <- function(handler, ret, dots, label) {
validation_fn(handler)(ret, dots, label)
}
28 changes: 12 additions & 16 deletions R/declare_design.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# f(mtcars)

#' @importFrom rlang quos lang_fn lang_modify eval_tidy
callquos_to_step <- function(step_call) {
callquos_to_step <- function(step_call, label="") {
## this function allows you to put any R expression
## such a dplyr::mutate partial call
## into the causal order, i.e.
Expand All @@ -43,17 +43,18 @@ callquos_to_step <- function(step_call) {

dots <- quos(!!data_name := data, !!!dots)

quo <- quo(currydata(fun, !!!dots))
#quo <- quo(currydata(fun, !!!dots))

curried <- currydata(fun, dots)

# curried <- eval_tidy(quo)

build_step(curried, handler=fun, dots=dots, label, step_type="wrapped", causal_type="dgp", call=step_call)

structure(curried,
call = step_call[[2]],
step_type = "wrapped",
causal_type="dgp")
# structure(curried,
# call = step_call[[2]],
# step_type = "wrapped",
# causal_type="dgp")

}

Expand Down Expand Up @@ -129,28 +130,23 @@ callquos_to_step <- function(step_call) {
declare_design <- function(...) {

qs <- quos(...)
qs <- maybe_add_labels(qs)
qnames <- names(qs)

ret <- structure(list(), class="design")
ret <- structure(vector("list", length(qs)), call=match.call(), class="design")

names(ret)[qnames != ""] <- qnames[qnames != ""]

if(getOption("DD.debug.declare_design", FALSE)) browser()

for(i in seq_along(qs)) {


#wrap step is nasty, converts partial call to curried function
ret[[i]] <- tryCatch(
eval_tidy(qs[[i]]),
error = function(e) callquos_to_step(qs[[i]])
error = function(e) callquos_to_step(qs[[i]], qnames[[i]])
)

if(qnames[[i]] != "") {
attr(ret[[i]], "label") <- qnames[[i]]
names(ret)[i] <- qnames[[i]]
}



}

# Special case for initializing with a data.frame
Expand Down
46 changes: 27 additions & 19 deletions R/declare_estimand.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,11 @@
#' my_estimand_custom(df)
#'

declare_estimand <- make_declarations(estimand_function_default, "estimand", causal_type="estimand", default_label="my_estimand",
validation=function(ret, handler, dots, label){
force(ret)
# add ... labels at build time
if(identical(estimand_function_default, handler)){
dotnames <- names(dots)

maybeDotLabel <- dotnames[! dotnames %in% c("", names(formals(handler)) )]
if(length(maybeDotLabel) == 1){
attr(ret, "steplabel") <- attr(ret, "label")
attr(ret, "label") <- maybeDotLabel[1]
}
}
ret
})
declare_estimand <- make_declarations(estimand_function_default, "estimand", causal_type="estimand", default_label="my_estimand")


#' @importFrom rlang eval_tidy quos is_quosure
estimand_function_default <- function(data, ..., subset = NULL, label) {
estimand_function_default <- function(data, ..., subset = NULL, coefficient_names=FALSE, label) {
options <- quos(...)
if(names(options)[1] == "") names(options)[1] <- label

Expand All @@ -75,8 +61,30 @@ estimand_function_default <- function(data, ..., subset = NULL, label) {
}
ret <- simplify2array(ret)

data.frame(estimand_label=names(options),
estimand=ret,
stringsAsFactors = FALSE)
if(coefficient_names){
data.frame(estimand_label=label,
coefficient_name=names(options),
estimand=ret,
stringsAsFactors = FALSE)


} else {
data.frame(estimand_label=names(options),
estimand=ret,
stringsAsFactors = FALSE)
}
}

attr(estimand_function_default, "validation_fn") <- function(ret, dots, label){
force(ret)
# add ... labels at build time
dotnames <- names(dots)

maybeDotLabel <- dotnames[! dotnames %in% c("", names(formals(estimand_function_default)) )]
if(length(maybeDotLabel) == 1){
attr(ret, "steplabel") <- attr(ret, "label")
attr(ret, "label") <- maybeDotLabel[1]
}

ret
}
4 changes: 2 additions & 2 deletions R/declare_estimator.R
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ estimator_handler <- function(data, ...,
stop("Please provide ", lbl, " function with a data argument.")
}

estimand_label <- switch(class(estimand), "character"=estimand, "function"=attributes(estimand)$label, NULL=NULL, warning("Did not match class of `estimand`"))
estimand_label <- switch(class(estimand)[1], "character"=estimand, "design_step"=attributes(estimand)$label, NULL=NULL, warning("Did not match class of `estimand`"))

# estimator_function_internal <- function(data) {
args <- quos(...)
Expand Down Expand Up @@ -147,7 +147,7 @@ fit2tidy <- function(fit, coefficient_name = NULL) {
colnames(return_data) <- c("coefficient_name","est", "se", "p", "ci_lower", "ci_upper")


if (!is.null(coefficient_name)) {
if (is.character(coefficient_name)) {
return_data <- return_data[return_data$coefficient_name %in% coefficient_name, ,drop = FALSE]
}

Expand Down
8 changes: 5 additions & 3 deletions R/diagnose_design.R
Original file line number Diff line number Diff line change
Expand Up @@ -199,15 +199,17 @@ diagnose_design_single_design <-
merge(
estimands_df,
estimates_df,
by = c("sim_ID", intersect("estimand_label", colnames(estimates_df))),
by = c("sim_ID",
"estimand_label" %i% colnames(estimates_df),
"coefficient_name" %i% colnames(estimands_df) %i% colnames(estimates_df)),
all = TRUE,
sort = FALSE
)
}

calculate_diagnosands <-
function(simulations_df, diagnosands){
group_by_set <- intersect(colnames(simulations_df), c("estimand_label", "estimator_label"))
group_by_set <- colnames(simulations_df) %i% c("estimand_label", "estimator_label", "coefficient_name")
group_by_list <- simulations_df[, group_by_set, drop=FALSE]

labels_df <- split(group_by_list, group_by_list, drop = TRUE)
Expand Down Expand Up @@ -241,7 +243,7 @@ diagnose_design_single_design <-
# replicate(bootstrap_sims, expr = boot_function(), simplify = FALSE)
# diagnosand_replicates <- do.call(rbind, diagnosand_replicates)

group_by_set <- intersect( colnames(diagnosand_replicates), c("estimand_label", "estimator_label"))
group_by_set <- colnames(diagnosand_replicates) %i% c("estimand_label", "estimator_label", "coefficient_name")
group_by_list <- diagnosand_replicates[, group_by_set, drop=FALSE]

labels_df <- split(group_by_list, group_by_list, drop=TRUE)
Expand Down
4 changes: 2 additions & 2 deletions R/diagnosis_helper_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ print.summary.diagnosis <- function(x, ...) {
cat("\nResearch design diagnosis\n\n")
print_diagnosis <- x
names(x) <-
gsub("(^|[[:space:]])([[:alpha:]])",
"\\1\\U\\2",
gsub("\\b(se[(]|sd |rmse|[[:alpha:]])",
"\\U\\1",
gsub("_", " ", names(x)),
perl = TRUE)
print(x, row.names = FALSE)
Expand Down
25 changes: 25 additions & 0 deletions R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,29 @@ wrap_step <- function(...) {

}



maybe_add_labels <- function(quotations){

labeller <- function(quotation, lbl) {
cx <- quotation[[2]]
if(is.call(cx) && is.symbol(cx[[1]]) && ! "label" %in% names(cx) && lbl != ""){

f <- match.fun(cx[[1]])
if("declaration" %in% class(f) && "label" %in% names(formals(f))){
quotation[[2]][["label"]] <- lbl
}
}
quotation
}

for(i in seq_along(quotations)){
quotations[[i]] <- labeller(quotations[[i]], names(quotations)[i])
}

quotations
}

# If <= 5 uniques, table it, ow descriptives if numeric-ish, ow number of levels.
describe_variable <- function(x) {

Expand Down Expand Up @@ -133,3 +156,5 @@ get_unique_variables_by_level <- function(data, ID_label, superset=NULL) {

return(level_variables)
}

`%i%` <- intersect
2 changes: 1 addition & 1 deletion tests/testthat/test-factorial.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@ test_that("Factorial", {

expect_equal(diagnosis %>% get_simulations %>% dim, c(2, 10))

expect_equal(diagnosis %>% get_diagnosands %>% dim, c(1,10))
expect_equal(diagnosis %>% get_diagnosands %>% dim, c(1,11))

})
28 changes: 28 additions & 0 deletions tests/testthat/test-multiple-coefficients.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
context("Factorial Design")

test_that("Factorial", {

alpha = 1
beta = 3

my_design <- declare_design(
my_pop = declare_population(N=30, noise=rnorm(N, mean = 0, sd=1), X=1:N, Y = alpha + beta*X + noise)
,
theta = declare_estimand(
`(Intercept)`=alpha,
X = beta,
coefficient_names = TRUE
)
,
OLS = declare_estimator(Y~X, model=lm, estimand="theta", coefficient_name=NULL)
)



diagnosis <- diagnose_design(my_design, sims = 2, bootstrap = FALSE, parallel = FALSE)

expect_equal(diagnosis %>% get_simulations %>% dim, c(4, 10))

expect_equal(diagnosis %>% get_diagnosands %>% dim, c(2,11))

})

0 comments on commit 5a16390

Please sign in to comment.