Skip to content

Commit

Permalink
Merge pull request #12 from Laksafoss/kija_dev
Browse files Browse the repository at this point in the history
Add more functions related to causal forest analysis
  • Loading branch information
kjakobse committed May 11, 2023
2 parents de94593 + e2bfa0d commit 5076f41
Show file tree
Hide file tree
Showing 18 changed files with 1,265 additions and 83 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Imports:
data.table,
doParallel,
dplyr,
forcats,
foreach,
ggplot2,
glue,
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ S3method(fct_confint,lm)
S3method(fct_confint,lms)
S3method(print,lms)
export(.datatable.aware)
export(CATESurface)
export(CForBenefit)
export(CausalForestDynamicSubgroups)
export(DiscreteCovariatesToOneHot)
export(RATEOmnibusTest)
export(RATETest)
export(braid_rows)
Expand Down
30 changes: 15 additions & 15 deletions R/kija_c_for_benefit.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
#' Calculates the c-for-benefit, as proposed by D. van Klaveren et
#' al. (2018), by matching patients based on patient characteristics.
#'
#' @param forest a causal forest object.
#' @param match character, "covariates" to match on covariates or "CATE" to
#' @param forest An object of class `causal_forest`, as returned by
#' \link[grf]{causal_forest}().
#' @param match character, `"covariates"` to match on covariates or `"CATE"` to
#' match on estimated CATE.
#' @param tau_hat_method character, "treatment" to calculate the expected
#' @param tau_hat_method character, `"treatment"` to calculate the expected
#' treatment effect in matched groups as the risk under treatment for the
#' treated subject minus the risk under control for the untreated
#' subject. "average" to calculate it as the average treatment effect of
#' subject. `"average"` to calculate it as the average treatment effect of
#' matched subject.
#' @param time_limit numeric, maximum allowed time to compute C-for-benefit. If
#' limit is reached, execution stops.
#' @param CI character, "none" for no confidence interval, "simple" to use a
#' normal approximation, and "bootstrap" to use the bootstrap.
#' @param CI character, `"none"` for no confidence interval, `"simple"` to use a
#' normal approximation, and `"bootstrap"` to use the bootstrap.
#' @param level numeric, confidence level of the confidence interval.
#' @param n_bootstraps numeric, number of bootstraps to use for the bootstrap
#' confidence interval computation.
Expand All @@ -23,18 +24,18 @@
#' should continue or be stopped.
#' @param verbose boolean, TRUE to display progress bar, FALSE to not display
#' progress bar.
#' @param method see MatchIt::matchit.
#' @param distance see MatchIt::matchit.
#' @param Y a vector of outcomes. If provided, replaces forest$Y.orig.
#' @param method see \link[MatchIt]{matchit}.
#' @param distance see \link[MatchIt]{matchit}.
#' @param Y a vector of outcomes. If provided, replaces `forest$Y.orig`.
#' @param W a vector of treatment assignment; 1 for active treatment; 0 for
#' control If provided, replaces forest$W.orig.
#' control If provided, replaces `forest$W.orig`.
#' @param X a matrix of patient characteristics. If provided, replaces
#' forest$X.orig.
#' @param p_0 a vector of outcome probabilities under control
#' @param p_1 a vector of outcome probabilities under active treatment
#' `forest$X.orig`.
#' @param p_0 a vector of outcome probabilities under control.
#' @param p_1 a vector of outcome probabilities under active treatment.
#' @param tau_hat a vector of individualized treatment effect predictions. If
#' provided, replaces forest$predictions.
#' @param ... additional arguments for MatchIt::matchit.
#' @param ... additional arguments for \link[MatchIt]{matchit}.
#'
#' @returns a list with the following components:
#'
Expand Down Expand Up @@ -71,7 +72,6 @@
#' and the baseline risk of the subject receiving control (tau_hat_method =
#' "treatment").
#'
#'
#' @author KIJA
#'
#' @examples
Expand Down
279 changes: 279 additions & 0 deletions R/kija_cate_surface.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
#' Calculate CATE on a surface in the covariate space
#'
#' Calculates CATE estimates from a causal forest object on a specified surface
#' within the covariate space.
#'
#' @param forest An object of class `causal_forest`, as returned by
#' \link[grf]{causal_forest}(). Alternatively, and object of class
#' `regression_forest`, as returned by \link[grf]{regression_forest}().
#' @param continuous_covariates character, continuous covariates to use for the
#' surface. Must match names in `forest$X.orig`.
#' @param discrete_covariates character, discrete covariates to use for the
#' surface. Note that discrete covariates are currently assumed to be one-hot
#' encoded with columns named `{fct_nm}_{lvl_nm}`. Names supplied to
#' discrete_covariates should match `fct_nm`.
#' @param estimate_variance boolean, If `TRUE`, the variance of CATE estimates
#' is computed.
#' @param grid list, points in which to predict CATE along continuous
#' covariates. Index i in the list should contain a numeric vectors with
#' either a single integer, specifying the number of equally spaced points
#' within the range of the i'th continuous covariate in which to calculate the
#' CATE, or a numeric vector with manually specified points in which to
#' calculate the CATE along the i'th continuous covariate. If all elements of
#' grid specify a number of points, this can be supplied using a numeric
#' vector. If the list is named, the names must match the continuous
#' covariates. grid will be reordered to match the order of
#' continuous_covariates.
#' @param fixed_covariate_fct Function applied to covariates not in the
#' sub-surface which returns the fixed value of the covariate used to
#' calculate the CATE. Must be specified in one of the following ways:
#' - A named function, e.g. `mean`.
#' - An anonymous function, e.g. \code{\(x) x + 1} or \code{function(x) x + 1}.
#' - A formula, e.g. \code{~ .x + 1}. You must use `.x` to refer to the
#' first argument. Only recommended if you require backward compatibility with
#' older versions of R.
#' - A string, integer, or list, e.g. `"idx"`, `1`, or `list("idx", 1)` which
#' are shorthand for \code{\(x) purrr::pluck(x, "idx")}, \code{\(x)
#' purrr::pluck(x, 1)}, and \code{\(x) purrr::pluck(x, "idx", 1)}
#' respectively. Optionally supply `.default` to set a default value if the
#' indexed element is `NULL` or does not exist.
#' @param other_discrete A data frame, data frame extension (e.g. a tibble), or
#' a lazy data frame (e.g. from dbplyr or dtplyr) with columns `covs` and
#' `lvl`. Used to specify the level of each discrete covariate to use when
#' calculating the CATE. assumes the use of one-hot encoding. `covs` must
#' contain the name of discrete covariates, and `lvl` the level to use. Set to
#' `NULL` if none of the fixed covariates are discrete using one-hot-encoding.
#' @param max_predict_size integer, maximum number of examples to predict at a
#' time. If the surface has more points than max_predict_size, the prediction
#' is split up into an appropriate number of chunks.
#' @param num_threads Number of threads used in training. If set to `NULL`, the
#' software automatically selects an appropriate amount.
#'
#' @return Tibble with the predicted CATE's on the specified surface in the
#' covariate space. The tibble has columns for each covariate used to train
#' the input forest, as well as columns output from
#' \link[grf]{predict.causal_forest}().
#'
#' @author KIJA
#'
#' @examples
#' n <- 1500
#' p <- 5
#' X <- matrix(rnorm(n * p), n, p) |> as.data.frame()
#' X_d <- data.frame(
#' X_d1 = factor(sample(1:5, n, replace = TRUE)),
#' X_d2 = factor(sample(1:5, n, replace = TRUE))
#' )
#' X_d <- DiscreteCovariatesToOneHot(X_d)
#' X <- cbind(X, X_d)
#' W <- rbinom(n, 1, 0.5)
#' event_prob <- 1 / (1 + exp(2 * (pmax(2 * X[, 1], 0) * W - X[, 2])))
#' Y <- rbinom(n, 1, event_prob)
#' cf <- grf::causal_forest(X, Y, W)
#' cate_surface <- CATESurface(
#' cf,
#' continuous_covariates = paste0("V", 1:4),
#' discrete_covariates = "X_d1",
#' grid = list(
#' V1 = 10,
#' V2 = 5,
#' V3 = -5:5,
#' V4 = 2
#' ),
#' other_discrete = data.frame(
#' covs = "X_d2",
#' lvl = "4"
#' )
#' )
#'
#' @export

CATESurface <- function(forest,
continuous_covariates,
discrete_covariates,
estimate_variance = TRUE,
grid = 100,
fixed_covariate_fct = median,
other_discrete = NULL,
max_predict_size = 100000,
num_threads = 2) {
# convert grid to list
if (is.numeric(grid) && length(grid) == 1) {
grid <- rep(grid, length(continuous_covariates))
}
grid <- as.list(grid)

# check input
stopifnot(
"'forest' must be an object of class causal_forest or regression_forest" =
any(c("causal_forest", "regression_forest") %in% class(forest))
)
stopifnot(
"continuous_covariates must be a character vector" =
is.character(continuous_covariates)
)
stopifnot(
"discrete_covariates must be a character vector" =
is.character(continuous_covariates)
)
stopifnot(
"estimate_variance must be a boolean (TRUE or FALSE)" =
isTRUE(estimate_variance) | isFALSE(estimate_variance)
)
stopifnot(
"'grid' must have the same length as 'continuous_covariates'" =
length(grid) == length(continuous_covariates)
)
lapply(
grid,
function(x) {
if (length(x) == 1 && !(is.numeric(x) && trunc(x) > 0.9)) {
stop("elements of 'grid' with length 1 must be positive integers.")
}
if (!(is.numeric(x))) {
stop("elements of 'grid' with length >1 must be numeric vectors.")
}
invisible(return(NULL))
}
)
if (!is.null(names(grid))) {
stopifnot(
"grid must be named after continuous_covariates." =
all(names(grid) %in% continuous_covariates) &&
all(continuous_covariates %in% names(grid))
)
}

# if named, reorder grid according to continuous_covariates
if (!is.null(names(grid))) {
grid <- grid[continuous_covariates]
}

# covariates in causal forest
if (is.null(colnames(forest$X.orig))) {
warning(
"Covariates used to train forest are unnamed. Names X_{colnr} are created.",
immediate. = TRUE
)
colnames(forest$X.orig) <- paste0("X_", seq_len(ncol(forest$X.orig)))
}
covariates <- colnames(forest$X.orig)

# formula with covariates
fmla <- formula(
paste0("~ 0 + ", paste0("`", covariates, "`", collapse = "+"))
)

# split covariates by use in sub-surface or not use in sub-surface
discrete_covariates_input <- discrete_covariates
discrete_covariates <- DiscreteCovariateNames(covariates, discrete_covariates)
other_covariates <- stringr::str_subset(
covariates,
paste0(
"^(",
paste0(c(discrete_covariates, continuous_covariates), collapse = "|"),
")"
),
negate = TRUE
)

# Generate grid of covariate values in which to evaluate CATE
X_continuous <- tibble::as_tibble(forest$X.orig) |>
dplyr::select(tidyselect::all_of(continuous_covariates))
data_grid <- purrr::map2(
X_continuous,
grid,
function(x, y) {
if (length(y) == 1) {
return(seq(range(x)[1], range(x)[2], length.out = y))
} else {
return(y)
}
}
)
discrete_values <- vector("list")
for(i in seq_along(discrete_covariates_input)) {
covariate_temp <- grep(
paste0("^", discrete_covariates_input[i]),
covariates,
value = TRUE
)
tibble_temp <- diag(1, length(covariate_temp), length(covariate_temp)) |>
tibble::as_tibble(.name_repair = "minimal") |>
rlang::set_names(nm = covariate_temp)
discrete_values[[i]] <- tibble_temp
}
names(discrete_values) <- discrete_covariates_input
data_grid <- c(data_grid, discrete_values)
X_other <- tibble::as_tibble(forest$X.orig) |>
dplyr::select(tidyselect::all_of(other_covariates))
fixed_values <- purrr::map_dbl(X_other, fixed_covariate_fct)
if(!is.null(other_discrete)) {
fixed_values[
grep(
paste0(
"^(",
paste0(other_discrete$covs, collapse = "|"),
")"
),
names(fixed_values),
value = TRUE
)
] <- 0
fixed_values[
grep(
paste0(
"^(",
paste0(
other_discrete$covs,
"_",
other_discrete$lvl,
"$",
collapse = "|"
),
")"
),
names(fixed_values),
value = TRUE
)
] <- 1
}
fixed_values <- as.list(fixed_values)
data_grid <- c(data_grid, fixed_values)
X_grid <- rlang::exec(tidyr::expand_grid, !!!data_grid)
X_grid <- tidyr::unnest(
X_grid,
cols = tidyselect::all_of(discrete_covariates_input)
)

# predict CATE in grid of points on surface
if(max_predict_size < nrow(X_grid)) {
X_grid_split <-
split(
seq_len(nrow(X_grid)), ceiling(seq_len(nrow(X_grid)) / max_predict_size)
)
tau_hat <- X_grid_split |>
purrr::map(
function(x) {
predict(
object = forest,
newdata = X_grid[x,],
estimate.variance = estimate_variance,
num.threads = num_threads)
}
) |>
purrr::list_rbind()
} else {
tau_hat <- predict(
object = forest,
newdata = X_grid,
estimate.variance = estimate_variance,
num.threads = num_threads)
}

# Return predicted CATE's on surface of covariate space
return(
X_grid |>
dplyr::bind_cols(tau_hat)
)
}
Loading

0 comments on commit 5076f41

Please sign in to comment.