diff --git a/R/bart.R b/R/bart.R
index a3e70bcd..a6668ce7 100644
--- a/R/bart.R
+++ b/R/bart.R
@@ -321,8 +321,20 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
     if (is.null(sigma2_init)) sigma2_init <- pct_var_sigma2_init*var(resid_train)
     if (is.null(variance_forest_init)) variance_forest_init <- pct_var_variance_forest_init*var(resid_train)
     if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean)
-    if (is.null(sigma_leaf_init)) sigma_leaf_init <- var(resid_train)/(num_trees_mean)
-    current_leaf_scale <- as.matrix(sigma_leaf_init)
+    # Initialize the leaf scale. With a multi-column basis (ncol(W_train) > 1)
+    # the scale is a matrix sized by ncol(W_train); otherwise it is 1 x 1.
+    if (has_basis && ncol(W_train) > 1) {
+        if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(var(resid_train)/(num_trees_mean), ncol(W_train))
+        if (is.matrix(sigma_leaf_init)) {
+            current_leaf_scale <- sigma_leaf_init
+        } else {
+            # A user-supplied non-matrix value is expanded onto the diagonal
+            current_leaf_scale <- diag(sigma_leaf_init, ncol(W_train))
+        }
+    } else {
+        if (is.null(sigma_leaf_init)) sigma_leaf_init <- var(resid_train)/(num_trees_mean)
+        current_leaf_scale <- as.matrix(sigma_leaf_init)
+    }
     current_sigma2 <- sigma2_init
 
     # Determine leaf model type
diff --git a/tools/debug/multivariate_bart_debug.R b/tools/debug/multivariate_bart_debug.R
new file mode 100644
index 00000000..da376696
--- /dev/null
+++ b/tools/debug/multivariate_bart_debug.R
@@ -0,0 +1,45 @@
+# Debugging script: runs stochtree::bart() with a two-column basis matrix W on
+# simulated data whose mean is a step function of X[,1] multiplied by W[,1].
+
+library(stochtree)
+
+# Generate the data
+n <- 500
+p_x <- 10
+p_w <- 2
+snr <- 3
+X <- matrix(runif(n*p_x), ncol = p_x)
+W <- matrix(runif(n*p_w), ncol = p_w)
+f_XW <- (
+    ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) +
+    ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) +
+    ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) +
+    ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1])
+)
+noise_sd <- sd(f_XW) / snr
+y <- f_XW + rnorm(n, 0, 1)*noise_sd
+
+# Split data into test and train sets
+test_set_pct <- 0.2
+n_test <- round(test_set_pct*n)
+n_train <- n - n_test
+test_inds <- sort(sample.int(n, n_test, replace = FALSE))
+train_inds <- setdiff(seq_len(n), test_inds)
+X_test <- as.data.frame(X[test_inds, , drop = FALSE])
+X_train <- as.data.frame(X[train_inds, , drop = FALSE])
+W_test <- W[test_inds, , drop = FALSE]
+W_train <- W[train_inds, , drop = FALSE]
+y_test <- y[test_inds]
+y_train <- y[train_inds]
+
+# Sample BART model
+num_gfr <- 10
+num_burnin <- 0
+num_mcmc <- 100
+num_samples <- num_gfr + num_burnin + num_mcmc
+bart_params <- list(sample_sigma_global = TRUE, sample_sigma_leaf = FALSE, num_trees_mean = 100)
+bart_model_warmstart <- stochtree::bart(
+    X_train = X_train, W_train = W_train, y_train = y_train, X_test = X_test, W_test = W_test,
+    num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
+    params = bart_params
+)