
Fixed several bugs in RFX sampler and improved R interface to RFX
andrewherren committed Jun 23, 2024
1 parent b790783 commit 63edb57
Showing 6 changed files with 170 additions and 22 deletions.
40 changes: 30 additions & 10 deletions R/bcf.R
@@ -55,6 +55,7 @@
#' @param adaptive_coding Whether or not to use an "adaptive coding" scheme in which a binary treatment variable is not coded manually as (0,1) or (-1,1) but learned via parameters `b_0` and `b_1` that attach to the outcome model `[b_0 (1-Z) + b_1 Z] tau(X)`. This is ignored when Z is not binary. Default: T.
#' @param b_0 Initial value of the "control" group coding parameter. This is ignored when Z is not binary. Default: -0.5.
#' @param b_1 Initial value of the "treatment" group coding parameter. This is ignored when Z is not binary. Default: 0.5.
#' @param rfx_prior_var Diagonal elements of the prior covariance for the random effects model. Must be a vector of length `ncol(rfx_basis_train)`. Default: `rep(1, ncol(rfx_basis_train))`.
#' @param random_seed Integer parameterizing the C++ random number generator. If not specified, the C++ random number generator is seeded according to `std::random_device`.
#' @param keep_burnin Whether or not "burnin" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
#' @param keep_gfr Whether or not "grow-from-root" samples should be included in cached predictions. Default FALSE. Ignored if num_mcmc = 0.
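
The new rfx_prior_var argument sets the diagonal of the prior covariance used for the random effects model and must have one entry per column of rfx_basis_train. A minimal usage sketch in R, assuming simulated data (everything below except the bcf() argument names shown in this diff is hypothetical):

# Hypothetical simulated data: 2-column random effects basis, 3 groups
n <- 100
X_train <- matrix(runif(n * 5), ncol = 5)
Z_train <- rbinom(n, 1, 0.5)
pi_train <- rep(0.5, n)
group_ids_train <- sample(1:3, n, replace = TRUE)
rfx_basis_train <- cbind(1, runif(n))
y_train <- 2 * Z_train + group_ids_train + rnorm(n)
# rfx_prior_var must have length ncol(rfx_basis_train) = 2
bcf_fit <- bcf(X_train = X_train, Z_train = Z_train, y_train = y_train, pi_train = pi_train,
               group_ids_train = group_ids_train, rfx_basis_train = rfx_basis_train,
               rfx_prior_var = c(1, 0.5))
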
@@ -119,8 +120,8 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
q = 0.9, sigma2 = NULL, variable_weights = NULL, keep_vars_mu = NULL, drop_vars_mu = NULL,
keep_vars_tau = NULL, drop_vars_tau = NULL, num_trees_mu = 250, num_trees_tau = 50,
num_gfr = 5, num_burnin = 0, num_mcmc = 100, sample_sigma_global = T, sample_sigma_leaf_mu = T,
sample_sigma_leaf_tau = F, propensity_covariate = "mu", adaptive_coding = T,
b_0 = -0.5, b_1 = 0.5, random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F) {
sample_sigma_leaf_tau = F, propensity_covariate = "mu", adaptive_coding = T, b_0 = -0.5,
b_1 = 0.5, rfx_prior_var = NULL, random_seed = -1, keep_burnin = F, keep_gfr = F, verbose = F) {
# Variable weight preprocessing (and initialization if necessary)
if (is.null(variable_weights)) {
variable_weights = rep(1/ncol(X_train), ncol(X_train))
@@ -294,6 +295,16 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU
}
}
}

# Random effects covariance prior
if (has_rfx) {
if (is.null(rfx_prior_var)) {
rfx_prior_var <- rep(1, ncol(rfx_basis_train))
} else {
if ((!is.integer(rfx_prior_var)) && (!is.numeric(rfx_prior_var))) stop("rfx_prior_var must be a numeric vector")
if (length(rfx_prior_var) != ncol(rfx_basis_train)) stop("length(rfx_prior_var) must equal ncol(rfx_basis_train)")
}
}

# Update variable weights
variable_weights_adj <- 1/sapply(original_var_indices, function(x) sum(original_var_indices == x))
@@ -342,7 +353,10 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU

# Check whether treatment is binary (specifically 0-1 binary)
binary_treatment <- length(unique(Z_train)) == 2
if (!(all(sort(unique(Z_train)) == c(0,1)))) binary_treatment <- F
if (binary_treatment) {
unique_treatments <- sort(unique(Z_train))
if (!(all(unique_treatments == c(0,1)))) binary_treatment <- F
}
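
The comparison against c(0, 1) now runs only when there are exactly two unique treatment values, presumably to avoid R's vector-recycling behavior (and warning) when Z_train is continuous or has more than two levels. A quick illustration of what the guard sidesteps, using hypothetical values:

Z_cont <- c(0.2, 0.5, 0.9)
# Comparing a length-3 vector to c(0, 1) recycles the shorter vector and warns
all(sort(unique(Z_cont)) == c(0, 1))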

# Adaptive coding will be ignored for continuous / ordered categorical treatments
if ((!binary_treatment) && (adaptive_coding)) {
@@ -413,16 +427,22 @@ bcf <- function(X_train, Z_train, y_train, pi_train = NULL, group_ids_train = NU

# Random effects prior parameters
if (has_rfx) {
if (num_rfx_components == 1) {
alpha_init <- c(1)
} else if (num_rfx_components > 1) {
alpha_init <- c(1,rep(0,num_rfx_components-1))
} else {
# Initialize the working parameter to 1
if (num_rfx_components < 1) {
stop("There must be at least 1 random effect component")
}
xi_init <- matrix(rep(alpha_init, num_rfx_groups),num_rfx_components,num_rfx_groups)
alpha_init <- rep(1,num_rfx_components)
# Initialize each group parameter based on a regression of the partial residual on the basis in that group
xi_init <- matrix(0,num_rfx_components,num_rfx_groups)
for (i in 1:num_rfx_groups) {
group_subset_indices <- group_ids_train == i
basis_group <- rfx_basis_train[group_subset_indices,]
resid_group <- resid_train[group_subset_indices]
rfx_group_model <- lm(resid_group ~ 0+basis_group)
xi_init[,i] <- unname(coef(rfx_group_model))
}
sigma_alpha_init <- diag(1,num_rfx_components,num_rfx_components)
sigma_xi_init <- diag(1,num_rfx_components,num_rfx_components)
sigma_xi_init <- diag(rfx_prior_var)
sigma_xi_shape <- 1
sigma_xi_scale <- 1
}
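
The initialization above fits, within each group, an ordinary least-squares regression of the partial residual on the random-effects basis and uses the coefficients as starting values for that group's parameters. A self-contained R sketch of the same idea on simulated data (all names below are illustrative, not from the package):

set.seed(101)
n <- 90
num_rfx_components <- 2
num_rfx_groups <- 3
group_ids <- sample(1:num_rfx_groups, n, replace = TRUE)
basis <- cbind(1, runif(n))                    # intercept column plus one slope column
resid <- rnorm(n) + group_ids                  # stand-in for the partial residual
xi_init <- matrix(0, num_rfx_components, num_rfx_groups)
for (i in 1:num_rfx_groups) {
  idx <- group_ids == i
  fit <- lm(resid[idx] ~ 0 + basis[idx, ])     # no intercept: the basis already carries one
  xi_init[, i] <- unname(coef(fit))
}
xi_init                                        # one column of starting values per group
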
9 changes: 5 additions & 4 deletions include/stochtree/random_effects.h
@@ -75,7 +75,7 @@ class LabelMapper {
auto pos = label_map_.find(category_id);
return pos != label_map_.end();
}
bool CategoryNumber(int32_t category_id) {
int32_t CategoryNumber(int32_t category_id) {
return label_map_[category_id];
}
std::vector<int32_t>& Keys() {return keys_;}
@@ -99,7 +99,7 @@ class MultivariateRegressionRandomEffectsModel {
working_parameter_ = Eigen::VectorXd(num_components_);
group_parameters_ = Eigen::MatrixXd(num_components_, num_groups_);
group_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
working_parameter_covariance_ = Eigen::VectorXd(num_components_, num_components_);
working_parameter_covariance_ = Eigen::MatrixXd(num_components_, num_components_);
}
~MultivariateRegressionRandomEffectsModel() {}

@@ -206,7 +206,7 @@ class MultivariateRegressionRandomEffectsModel {
tracker.SetPrediction(i, new_pred);
}
}
private:

/*! \brief Compute the posterior mean of the working parameter, conditional on the group parameters and the variance components */
Eigen::VectorXd WorkingParameterMean(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance);
/*! \brief Compute the posterior covariance of the working parameter, conditional on the group parameters and the variance components */
@@ -219,7 +219,8 @@ class MultivariateRegressionRandomEffectsModel {
double VarianceComponentShape(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t component_id);
/*! \brief Compute the posterior scale of the group variance component, conditional on the working and group parameters */
double VarianceComponentScale(RandomEffectsDataset& dataset, ColumnVector& residual, RandomEffectsTracker& rfx_tracker, double global_variance, int32_t component_id);


private:
/*! \brief Samplers */
MultivariateNormalSampler normal_sampler_;
InverseGammaSampler ig_sampler_;
15 changes: 7 additions & 8 deletions src/random_effects.cpp
@@ -45,8 +45,8 @@ void MultivariateRegressionRandomEffectsModel::SampleRandomEffects(RandomEffects
AddCurrentPredictionToResidual(dataset, rfx_tracker, residual);

// Sample random effects
SampleWorkingParameter(dataset, residual, rfx_tracker, global_variance, gen);
SampleGroupParameters(dataset, residual, rfx_tracker, global_variance, gen);
SampleWorkingParameter(dataset, residual, rfx_tracker, global_variance, gen);
SampleVarianceComponents(dataset, residual, rfx_tracker, global_variance, gen);

// Update partial residual to remove the random effects
@@ -104,8 +104,8 @@ Eigen::VectorXd MultivariateRegressionRandomEffectsModel::WorkingParameterMean(R
X_group = X(observation_indices, Eigen::all);
y_group = y(observation_indices, Eigen::all);
xi_group = xi(Eigen::all, i);
posterior_denominator += (xi_group).asDiagonal() * X_group.transpose() * X_group * (xi_group).asDiagonal();
posterior_numerator += (xi_group).asDiagonal() * X_group.transpose() * y_group;
posterior_denominator += ((xi_group).asDiagonal() * X_group.transpose() * X_group * (xi_group).asDiagonal()) / global_variance;
posterior_numerator += (xi_group).asDiagonal() * X_group.transpose() * y_group / global_variance;
}
return posterior_denominator.inverse() * posterior_numerator;
}
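
The fix here, repeated in the WorkingParameterVariance, GroupParameterMean, and GroupParameterVariance hunks below, scales the data term by global_variance. This matches the standard conjugate Gaussian update for the model y_i = x_i' diag(xi_{g(i)}) alpha + eps_i with eps_i ~ N(0, sigma^2). A sketch of the working-parameter update (the prior precision term is added outside the hunks shown here):

\alpha \mid \cdot \;\sim\; \mathcal{N}\!\left(\widehat{\Sigma}\, b,\ \widehat{\Sigma}\right),
\qquad
\widehat{\Sigma} = \left( \Sigma_\alpha^{-1} + \frac{1}{\sigma^2} \sum_{g} \operatorname{diag}(\xi_g)\, X_g^\top X_g\, \operatorname{diag}(\xi_g) \right)^{-1},
\qquad
b = \frac{1}{\sigma^2} \sum_{g} \operatorname{diag}(\xi_g)\, X_g^\top y_g,

where X_g and y_g are the basis and partial residual for group g. The group-parameter updates are analogous with the roles of alpha and xi_g exchanged. Omitting the 1/sigma^2 factor, as the previous code did, over-weights the likelihood relative to the prior whenever sigma^2 differs from 1.
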
@@ -127,8 +127,7 @@ Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::WorkingParameterVarian
X_group = X(observation_indices, Eigen::all);
y_group = y(observation_indices, Eigen::all);
xi_group = xi(Eigen::all, i);
posterior_denominator += (xi_group).asDiagonal() * X_group.transpose() * X_group * (xi_group).asDiagonal();
posterior_numerator += (xi_group).asDiagonal() * X_group.transpose() * y_group;
posterior_denominator += ((xi_group).asDiagonal() * X_group.transpose() * X_group * (xi_group).asDiagonal()) / (global_variance);
}
return posterior_denominator.inverse();
}
@@ -144,8 +143,8 @@ Eigen::VectorXd MultivariateRegressionRandomEffectsModel::GroupParameterMean(Ran
std::vector<data_size_t> observation_indices = rfx_tracker.NodeIndicesInternalIndex(group_id);
Eigen::MatrixXd X_group = X(observation_indices, Eigen::all);
Eigen::VectorXd y_group = y(observation_indices, Eigen::all);
posterior_denominator += (alpha).asDiagonal() * X_group.transpose() * X_group * (alpha).asDiagonal();
posterior_numerator += (alpha).asDiagonal() * X_group.transpose() * y_group;
posterior_denominator += ((alpha).asDiagonal() * X_group.transpose() * X_group * (alpha).asDiagonal()) / (global_variance);
posterior_numerator += (alpha).asDiagonal() * X_group.transpose() * y_group / global_variance;
return posterior_denominator.inverse() * posterior_numerator;
}

@@ -160,7 +159,7 @@ Eigen::MatrixXd MultivariateRegressionRandomEffectsModel::GroupParameterVariance
std::vector<data_size_t> observation_indices = rfx_tracker.NodeIndicesInternalIndex(group_id);
Eigen::MatrixXd X_group = X(observation_indices, Eigen::all);
// Eigen::VectorXd y_group = y(observation_indices, Eigen::all);
posterior_denominator += (alpha).asDiagonal() * X_group.transpose() * X_group * (alpha).asDiagonal();
posterior_denominator += ((alpha).asDiagonal() * X_group.transpose() * X_group * (alpha).asDiagonal()) / (global_variance);
// posterior_numerator += (alpha).asDiagonal() * X_group.transpose() * y_group;
return posterior_denominator.inverse();
}
57 changes: 57 additions & 0 deletions test/cpp/test_random_effects.cpp
@@ -99,6 +99,63 @@ TEST(RandomEffects, Construction) {
}
}

TEST(RandomEffects, Computation) {
// Load test data
StochTree::TestUtils::TestDataset test_dataset;
test_dataset = StochTree::TestUtils::LoadSmallRFXDatasetMultivariateBasis();
std::vector<StochTree::FeatureType> feature_types(test_dataset.x_cols, StochTree::FeatureType::kNumeric);

// Construct dataset
int n = test_dataset.n;
StochTree::ColumnVector residual = StochTree::ColumnVector(test_dataset.outcome.data(), test_dataset.n);
StochTree::RandomEffectsDataset dataset = StochTree::RandomEffectsDataset();
dataset.AddBasis(test_dataset.rfx_basis.data(), test_dataset.n, test_dataset.rfx_basis_cols, test_dataset.row_major);
dataset.AddGroupLabels(test_dataset.rfx_groups);

// Construct tracker, model state, and container
StochTree::RandomEffectsTracker tracker = StochTree::RandomEffectsTracker(test_dataset.rfx_groups);
StochTree::MultivariateRegressionRandomEffectsModel model = StochTree::MultivariateRegressionRandomEffectsModel(test_dataset.rfx_basis_cols, test_dataset.rfx_num_groups);
StochTree::RandomEffectsContainer container = StochTree::RandomEffectsContainer(test_dataset.rfx_basis_cols, test_dataset.rfx_num_groups);
StochTree::LabelMapper label_mapper = StochTree::LabelMapper(tracker.GetLabelMap());

// Set the values of alpha, xi and sigma in the model state (rather than simulating)
Eigen::VectorXd alpha(test_dataset.rfx_basis_cols);
Eigen::MatrixXd xi(test_dataset.rfx_basis_cols, test_dataset.rfx_num_groups);
Eigen::MatrixXd sigma(test_dataset.rfx_basis_cols, test_dataset.rfx_basis_cols);
alpha << 1., 1.;
xi << 1., 1., 1., 1., 1., 1.;
Eigen::VectorXd xi0 = xi(Eigen::all, 0);
Eigen::VectorXd xi1 = xi(Eigen::all, 1);
Eigen::VectorXd xi2 = xi(Eigen::all, 2);
sigma << 1, 0, 0, 1;
model.SetWorkingParameter(alpha);
model.SetGroupParameter(xi0, 0);
model.SetGroupParameter(xi1, 1);
model.SetGroupParameter(xi2, 2);
model.SetGroupParameterCovariance(sigma);
double sigma2 = 1.;

// Compute the posterior mean for the group parameters
Eigen::VectorXd xi0_mean = model.GroupParameterMean(dataset, residual, tracker, sigma2, 0);
Eigen::VectorXd xi1_mean = model.GroupParameterMean(dataset, residual, tracker, sigma2, 1);
Eigen::VectorXd xi2_mean = model.GroupParameterMean(dataset, residual, tracker, sigma2, 2);

// Check data in the container
std::vector<double> xi_mean_expected(test_dataset.rfx_basis_cols);
xi_mean_expected = {0.6979496, 0.3316027};
for (int i = 0; i < xi_mean_expected.size(); i++) {
ASSERT_NEAR(xi0_mean(i), xi_mean_expected[i], 0.001);
}
xi_mean_expected = {0.65744523, 0.00639347};
for (int i = 0; i < xi_mean_expected.size(); i++) {
ASSERT_NEAR(xi1_mean(i), xi_mean_expected[i], 0.001);
}
xi_mean_expected = {0.8763421, 0.3414047};
for (int i = 0; i < xi_mean_expected.size(); i++) {
ASSERT_NEAR(xi2_mean(i), xi_mean_expected[i], 0.001);
}
}
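
With alpha = (1, 1), the group-parameter covariance set to the identity, and sigma2 = 1, the conjugate update sketched above reduces each group posterior mean to (I + X_g' X_g)^{-1} X_g' y_g, which is consistent with the reference values checked here. A short R sketch (not part of the test suite) that reproduces them from the dataset defined in LoadSmallRFXDatasetMultivariateBasis:

basis <- cbind(1, c(0.3707159, 0.1312134, 0.5614470, 0.2276504, 0.9029984,
                    0.7448547, 0.2549813, 0.5295535, 0.5584614, 0.2365117))
outcome <- c(2.158854445, 1.175387297, 0.40481061, 1.751578365, 0.299641379,
             0.347249942, 0.546179903, 1.164750138, 3.389946886, -0.605464414)
groups <- c(1, 2, 3, 1, 2, 3, 1, 2, 3, 1)
for (g in 1:3) {
  X_g <- basis[groups == g, ]
  y_g <- outcome[groups == g]
  # Identity prior precision plus data term; alpha = (1, 1) and sigma2 = 1 drop out
  post_mean <- solve(diag(2) + crossprod(X_g), crossprod(X_g, y_g))
  print(round(drop(post_mean), 7))
}
# Printed vectors should agree (to the 0.001 tolerance used above) with
# {0.6979496, 0.3316027}, {0.65744523, 0.00639347}, and {0.8763421, 0.3414047}.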

TEST(RandomEffects, Predict) {
// Load test data
StochTree::TestUtils::TestDataset test_dataset;
68 changes: 68 additions & 0 deletions test/cpp/testutils.cpp
@@ -118,6 +118,74 @@ TestDataset LoadSmallDatasetMultivariateBasis() {
return output;
}

TestDataset LoadSmallRFXDatasetMultivariateBasis() {
TestDataset output;

// Data dimensions
output.n = 10;
output.x_cols = 5;
output.omega_cols = 2;
output.rfx_basis_cols = 2;
output.covariates.resize(output.n, output.x_cols);
output.omega.resize(output.n, output.omega_cols);
output.rfx_basis.resize(output.n, output.rfx_basis_cols);
output.rfx_groups.resize(output.n);
output.outcome.resize(output.n);

// Covariates
output.covariates << 0.766969853, 0.83894646, 0.63649772, 0.6747788934, 0.27398269,
0.634970996, 0.15237997, 0.3800786, 0.6457891271, 0.21604451,
0.229598754, 0.12461481, 0.81407372, 0.364336529, 0.45160373,
0.741084778, 0.53356288, 0.58940162, 0.9995219493, 0.19142269,
0.618177813, 0.88876378, 0.51174404, 0.8827708189, 0.12730742,
0.858657839, 0.9271676, 0.5115294, 0.67865624, 0.28658962,
0.719224842, 0.0546961, 0.42850897, 0.260336376, 0.1371501,
0.747422328, 0.87172033, 0.98791964, 0.4018020707, 0.29145664,
0.3158837, 0.39253551, 0.83610831, 0.0101785748, 0.1955386,
0.419554105, 0.5586495, 0.19908607, 0.4873921743, 0.35568569;

// Leaf regression basis
output.omega << 0.97801674, 0.3707159,
0.34045661, 0.1312134,
0.20528387, 0.5614470,
0.76230322, 0.2276504,
0.63244655, 0.9029984,
0.61225851, 0.7448547,
0.40492125, 0.2549813,
0.33112223, 0.5295535,
0.86917047, 0.5584614,
0.58444831, 0.2365117;

// Outcome
output.outcome << 2.158854445, 1.175387297, 0.40481061, 1.751578365, 0.299641379,
0.347249942, 0.546179903, 1.164750138, 3.389946886, -0.605464414;

// Random effects regression basis (i.e. constant, intercept-only RFX model)
output.rfx_basis << 1, 0.3707159,
1, 0.1312134,
1, 0.5614470,
1, 0.2276504,
1, 0.9029984,
1, 0.7448547,
1, 0.2549813,
1, 0.5295535,
1, 0.5584614,
1, 0.2365117;

// Random effects group labels
output.rfx_groups = {1,2,3,1,2,3,1,2,3,1};
// for (int i = 0; i < output.n; i++) {
// if (i % 2 == 0) {
// output.rfx_groups[i] = 1;
// } else {
// output.rfx_groups[i] = 2;
// }
// }
output.rfx_num_groups = 3;

return output;
}

TestDataset LoadMediumDatasetUnivariateBasis() {
TestDataset output;

3 changes: 3 additions & 0 deletions test/cpp/testutils.h
@@ -33,6 +33,9 @@ TestDataset LoadSmallDatasetUnivariateBasis();
/*! Creates a small dataset (10 observations) with a multivariate basis for leaf regression applications */
TestDataset LoadSmallDatasetMultivariateBasis();

/*! Creates a small dataset (10 observations) with a multivariate basis and several random effects terms */
TestDataset LoadSmallRFXDatasetMultivariateBasis();

/*! Creates a modest dataset (100 observations) */
TestDataset LoadMediumDatasetUnivariateBasis();

