Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
5e4dc8d
Initial CRAN release preparation commit
andrewherren Jan 23, 2025
2192fd2
Fixing CRAN check warning with generic predict overload
andrewherren Jan 23, 2025
f02ed71
Updated R package to pass CRAN checks
andrewherren Jan 23, 2025
e3b3481
Updated BART R API
andrewherren Jan 24, 2025
1c60f55
Propagate parameter name changes in BART and BCF APIs
andrewherren Jan 24, 2025
23f7eee
Updated parameter names in examples and vignettes and added to inst/C…
andrewherren Jan 24, 2025
d7efd51
Updated aspirational README ahead of CRAN submission
andrewherren Jan 26, 2025
899beba
Updated docs and CRAN bootstrap script
andrewherren Jan 29, 2025
93d93c2
Updated R docstrings
andrewherren Jan 29, 2025
4fd36be
Updated calibration function name
andrewherren Jan 29, 2025
36aa106
Updated interface and function names
andrewherren Jan 29, 2025
0dc3928
Simplifying preprocessing API (and temporarily disabling deprecated p…
andrewherren Jan 29, 2025
0ff1c7f
Removed convertBARTStateToJson function
andrewherren Jan 29, 2025
4ce1510
Switched from convert to save for JSON model
andrewherren Jan 29, 2025
5096080
Adding createBCFModelFromCombinedJson function
andrewherren Jan 29, 2025
5167b64
Updated R interface
andrewherren Jan 30, 2025
3c9721c
Updated R interface
andrewherren Jan 30, 2025
0cfe8c4
Updated R interface
andrewherren Jan 30, 2025
78be2e5
Updated R interface to make variance samplers consistent with functio…
andrewherren Jan 30, 2025
573265d
Updated GHA workflow for R CMD Check
andrewherren Jan 30, 2025
315fc59
Updating R docs to include more examples
andrewherren Jan 30, 2025
d58bda8
Updated RFX docs
andrewherren Jan 30, 2025
86fd507
Updated R serialization docs
andrewherren Jan 30, 2025
12299e1
Updated variance and preprocessing docs
andrewherren Jan 30, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
^.*\.Rproj$
^\.Rproj\.user$
^cran-comments\.md$
2 changes: 1 addition & 1 deletion .github/workflows/r-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:

- name: Create a CRAN-ready version of the R package
run: |
Rscript cran-bootstrap.R 0
Rscript cran-bootstrap.R 0 0

- uses: r-lib/actions/check-r-package@v2
with:
Expand Down
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
Package: stochtree
Title: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference
Version: 0.0.1
Version: 0.1.0
Authors@R:
c(
person("Drew", "Herren", email = "drewherrenopensource@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0000-0003-4109-6611")),
person("Richard", "Hahn", role = "aut"),
person("Jared", "Murray", role = "aut"),
person("Carlos", "Carvalho", role = "aut"),
person("Jingyu", "He", role = "aut")
person("Jingyu", "He", role = "aut"),
person("stochtree contributors", role = c("cph"))
)
Description: Stochastic tree ensembles (XBART and BART) for supervised learning and causal inference.
License: MIT + file LICENSE
Expand Down
34 changes: 12 additions & 22 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,40 +1,37 @@
# Generated by roxygen2: do not edit by hand

S3method(getRandomEffectSamples,bartmodel)
S3method(getRandomEffectSamples,bcf)
S3method(getRandomEffectSamples,bcfmodel)
S3method(predict,bartmodel)
S3method(predict,bcf)
S3method(predict,bcfmodel)
export(bart)
export(bcf)
export(calibrate_inverse_gamma_error_variance)
export(calibrateInverseGammaErrorVariance)
export(computeForestLeafIndices)
export(computeForestLeafVariances)
export(computeMaxLeafIndex)
export(convertBARTModelToJson)
export(convertBCFModelToJson)
export(computeForestMaxLeafIndex)
export(convertPreprocessorToJson)
export(createBARTModelFromCombinedJson)
export(createBARTModelFromCombinedJsonString)
export(createBARTModelFromJson)
export(createBARTModelFromJsonFile)
export(createBARTModelFromJsonString)
export(createBCFModelFromCombinedJson)
export(createBCFModelFromCombinedJsonString)
export(createBCFModelFromJson)
export(createBCFModelFromJsonFile)
export(createBCFModelFromJsonString)
export(createCppJson)
export(createCppJsonFile)
export(createCppJsonString)
export(createCppRNG)
export(createForest)
export(createForestContainer)
export(createForestCovariates)
export(createForestCovariatesFromMetadata)
export(createForestDataset)
export(createForestModel)
export(createForestSamples)
export(createOutcome)
export(createPreprocessorFromJson)
export(createPreprocessorFromJsonString)
export(createRNG)
export(createRandomEffectSamples)
export(createRandomEffectsDataset)
export(createRandomEffectsModel)
Expand All @@ -48,35 +45,28 @@ export(loadRandomEffectSamplesCombinedJsonString)
export(loadRandomEffectSamplesJson)
export(loadScalarJson)
export(loadVectorJson)
export(oneHotEncode)
export(oneHotInitializeAndEncode)
export(orderedCatInitializeAndPreprocess)
export(orderedCatPreprocess)
export(preprocessParams)
export(preprocessPredictionData)
export(preprocessPredictionDataFrame)
export(preprocessPredictionMatrix)
export(preprocessTrainData)
export(preprocessTrainDataFrame)
export(preprocessTrainMatrix)
export(resetActiveForest)
export(resetForestModel)
export(resetRandomEffectsModel)
export(resetRandomEffectsTracker)
export(rootResetActiveForest)
export(rootResetRandomEffectsModel)
export(rootResetRandomEffectsTracker)
export(sample_sigma2_one_iteration)
export(sample_tau_one_iteration)
export(sampleGlobalErrorVarianceOneIteration)
export(sampleLeafVarianceOneIteration)
export(saveBARTModelToJson)
export(saveBARTModelToJsonFile)
export(saveBARTModelToJsonString)
export(saveBCFModelToJson)
export(saveBCFModelToJsonFile)
export(saveBCFModelToJsonString)
export(savePreprocessorToJsonString)
importFrom(R6,R6Class)
importFrom(stats,coef)
importFrom(stats,lm)
importFrom(stats,model.matrix)
importFrom(stats,predict)
importFrom(stats,qgamma)
importFrom(stats,resid)
importFrom(stats,rnorm)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# stochtree 0.1.0

* Initial CRAN submission.
375 changes: 159 additions & 216 deletions R/bart.R

Large diffs are not rendered by default.

659 changes: 437 additions & 222 deletions R/bcf.R

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions R/calibration.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
#' X <- matrix(runif(n*p), ncol = p)
#' y <- 10*X[,1] - 20*X[,2] + rnorm(n)
#' nu <- 3
#' lambda <- calibrate_inverse_gamma_error_variance(y, X, nu = nu)
#' lambda <- calibrateInverseGammaErrorVariance(y, X, nu = nu)
#' sigma2hat <- mean(resid(lm(y~X))^2)
#' mean(var(y)/rgamma(100000, nu, rate = nu*lambda) < sigma2hat)
calibrate_inverse_gamma_error_variance <- function(y, X, W = NULL, nu = 3, quant = 0.9, standardize = TRUE) {
calibrateInverseGammaErrorVariance <- function(y, X, W = NULL, nu = 3, quant = 0.9, standardize = TRUE) {
# Compute regression basis
if (!is.null(W)) basis <- cbind(X, W)
else basis <- X
Expand Down
8 changes: 4 additions & 4 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -292,8 +292,8 @@ json_load_forest_container_cpp <- function(forest_samples, json_filename) {
invisible(.Call(`_stochtree_json_load_forest_container_cpp`, forest_samples, json_filename))
}

output_dimension_forest_container_cpp <- function(forest_samples) {
.Call(`_stochtree_output_dimension_forest_container_cpp`, forest_samples)
leaf_dimension_forest_container_cpp <- function(forest_samples) {
.Call(`_stochtree_leaf_dimension_forest_container_cpp`, forest_samples)
}

is_leaf_constant_forest_container_cpp <- function(forest_samples) {
Expand Down Expand Up @@ -464,8 +464,8 @@ predict_raw_active_forest_cpp <- function(active_forest, dataset) {
.Call(`_stochtree_predict_raw_active_forest_cpp`, active_forest, dataset)
}

output_dimension_active_forest_cpp <- function(active_forest) {
.Call(`_stochtree_output_dimension_active_forest_cpp`, active_forest)
leaf_dimension_active_forest_cpp <- function(active_forest) {
.Call(`_stochtree_leaf_dimension_active_forest_cpp`, active_forest)
}

average_max_depth_active_forest_cpp <- function(active_forest) {
Expand Down
20 changes: 20 additions & 0 deletions R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,14 @@ RandomEffectsDataset <- R6::R6Class(
#'
#' @return `ForestDataset` object
#' @export
#'
#' @examples
#' covariate_matrix <- matrix(runif(10*100), ncol = 10)
#' basis_matrix <- matrix(rnorm(3*100), ncol = 3)
#' weight_vector <- rnorm(100)
#' forest_dataset <- createForestDataset(covariate_matrix)
#' forest_dataset <- createForestDataset(covariate_matrix, basis_matrix)
#' forest_dataset <- createForestDataset(covariate_matrix, basis_matrix, weight_vector)
createForestDataset <- function(covariates, basis=NULL, variance_weights=NULL){
return(invisible((
ForestDataset$new(covariates, basis, variance_weights)
Expand All @@ -240,6 +248,11 @@ createForestDataset <- function(covariates, basis=NULL, variance_weights=NULL){
#'
#' @return `Outcome` object
#' @export
#'
#' @examples
#' X <- matrix(runif(10*100), ncol = 10)
#' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100)
#' outcome <- createOutcome(y)
createOutcome <- function(outcome){
return(invisible((
Outcome$new(outcome)
Expand All @@ -254,6 +267,13 @@ createOutcome <- function(outcome){
#'
#' @return `RandomEffectsDataset` object
#' @export
#'
#' @examples
#' rfx_group_ids <- sample(1:2, size = 100, replace = TRUE)
#' rfx_basis <- matrix(rnorm(3*100), ncol = 3)
#' weight_vector <- rnorm(100)
#' rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis)
#' rfx_dataset <- createRandomEffectsDataset(rfx_group_ids, rfx_basis, weight_vector)
createRandomEffectsDataset <- function(group_labels, basis, variance_weights=NULL){
return(invisible((
RandomEffectsDataset$new(group_labels, basis, variance_weights)
Expand Down
Loading
Loading