diff --git a/demo/notebooks/prototype_interface.ipynb b/demo/notebooks/prototype_interface.ipynb
index 64d19808..79e0586e 100644
--- a/demo/notebooks/prototype_interface.ipynb
+++ b/demo/notebooks/prototype_interface.ipynb
@@ -678,7 +678,7 @@
    "outputs": [],
    "source": [
     "forest_pred_avg_gfr = np.squeeze(forest_preds_mu_gfr).mean(axis = 1, keepdims = True)\n",
-    "forest_pred_df_gfr = pd.DataFrame(np.concatenate((np.expand_dims(tau_X,1), forest_pred_avg_gfr), axis = 1), columns=[\"True mu\", \"Average estimated mu\"])\n",
+    "forest_pred_df_gfr = pd.DataFrame(np.concatenate((np.expand_dims(mu_X,1), forest_pred_avg_gfr), axis = 1), columns=[\"True mu\", \"Average estimated mu\"])\n",
     "sns.scatterplot(data=forest_pred_df_gfr, x=\"True mu\", y=\"Average estimated mu\")\n",
     "plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n",
     "plt.show()"
@@ -734,7 +734,7 @@
    "outputs": [],
    "source": [
     "forest_pred_avg_mcmc = np.squeeze(forest_preds_mu_mcmc).mean(axis = 1, keepdims = True)\n",
-    "forest_pred_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(tau_X,1), forest_pred_avg_mcmc), axis = 1), columns=[\"True mu\", \"Average estimated mu\"])\n",
+    "forest_pred_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(mu_X,1), forest_pred_avg_mcmc), axis = 1), columns=[\"True mu\", \"Average estimated mu\"])\n",
     "sns.scatterplot(data=forest_pred_df_mcmc, x=\"True mu\", y=\"Average estimated mu\")\n",
     "plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n",
     "plt.show()"
diff --git a/stochtree/__init__.py b/stochtree/__init__.py
index da68f6fb..f4aa122b 100644
--- a/stochtree/__init__.py
+++ b/stochtree/__init__.py
@@ -1,5 +1,6 @@
 from .bart import BARTModel
 from .bcf import BCFModel
+from .calibration import calibrate_global_error_variance
 from .data import Dataset, Residual
 from .forest import ForestContainer
 from .preprocessing import CovariateTransformer
@@ -9,4 +10,4 @@
 
 __all__ = ['BARTModel', 'BCFModel', 'Dataset', 'Residual', 'ForestContainer', 
            'CovariateTransformer', 'RNG', 'ForestSampler', 'GlobalVarianceModel', 
-           'LeafVarianceModel', 'JSONSerializer', 'NotSampledError']
\ No newline at end of file
+           'LeafVarianceModel', 'JSONSerializer', 'NotSampledError', 'calibrate_global_error_variance']
\ No newline at end of file
diff --git a/stochtree/bart.py b/stochtree/bart.py
index e6a35739..df88fb01 100644
--- a/stochtree/bart.py
+++ b/stochtree/bart.py
@@ -3,8 +3,10 @@
 """
 import numpy as np
 import pandas as pd
-from scipy.linalg import lstsq
+from sklearn import linear_model
+from sklearn.metrics import mean_squared_error
 from scipy.stats import gamma
+from .calibration import calibrate_global_error_variance
 from .data import Dataset, Residual
 from .forest import ForestContainer
 from .preprocessing import CovariateTransformer
@@ -173,14 +175,11 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N
         self.y_std = np.squeeze(np.std(y_train))
         resid_train = (y_train-self.y_bar)/self.y_std
 
-        # Calibrate priors for global sigma^2 and sigma_leaf
-        if lamb is None:
-            reg_basis = np.c_[np.ones(self.n_train),X_train_processed]
-            reg_soln = lstsq(reg_basis, np.squeeze(resid_train))
-            sigma2hat = reg_soln[1] / self.n_train
-            quantile_cutoff = q
-            lamb = (sigma2hat*gamma.ppf(1-quantile_cutoff,nu))/nu
-            sigma2 = sigma2hat if sigma2 is None else sigma2
+        # Calibrate priors for global sigma^2 and sigma_leaf (don't use regression initializer for warm-start or XBART)
+        if num_gfr > 0:
+            sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, False)
+        else:
+            sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, True)
         b_leaf = np.squeeze(np.var(resid_train)) / num_trees if b_leaf is None else b_leaf
         sigma_leaf = np.squeeze(np.var(resid_train)) / num_trees if sigma_leaf is None else sigma_leaf
         current_sigma2 = sigma2
diff --git a/stochtree/bcf.py b/stochtree/bcf.py
index b2be94db..96f3ae84 100644
--- a/stochtree/bcf.py
+++ b/stochtree/bcf.py
@@ -10,6 +10,7 @@
 from scipy.linalg import lstsq
 from scipy.stats import gamma
 from .bart import BARTModel
+from .calibration import calibrate_global_error_variance
 from .data import Dataset, Residual
 from .forest import ForestContainer
 from .preprocessing import CovariateTransformer
@@ -510,14 +511,11 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
         self.y_std = np.squeeze(np.std(y_train))
         resid_train = (y_train-self.y_bar)/self.y_std
 
-        # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau
-        if lamb is None:
-            reg_basis = np.c_[np.ones(self.n_train),X_train_processed]
-            reg_soln = lstsq(reg_basis, np.squeeze(resid_train))
-            sigma2hat = reg_soln[1] / self.n_train
-            quantile_cutoff = q
-            lamb = (sigma2hat*gamma.ppf(1-quantile_cutoff,nu))/nu
-            sigma2 = sigma2hat if sigma2 is None else sigma2
+        # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau (don't use regression initializer for warm-start or XBART)
+        if num_gfr > 0:
+            sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, False)
+        else:
+            sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, True)
         b_leaf_mu = np.squeeze(np.var(resid_train)) / num_trees_mu if b_leaf_mu is None else b_leaf_mu
         b_leaf_tau = np.squeeze(np.var(resid_train)) / (2*num_trees_tau) if b_leaf_tau is None else b_leaf_tau
         sigma_leaf_mu = np.squeeze(np.var(resid_train)) / num_trees_mu if sigma_leaf_mu is None else sigma_leaf_mu
diff --git a/stochtree/calibration.py b/stochtree/calibration.py
new file mode 100644
index 00000000..20a3a27f
--- /dev/null
+++ b/stochtree/calibration.py
@@ -0,0 +1,96 @@
+import warnings
+import numpy as np
+import pandas as pd
+from sklearn import linear_model
+from sklearn.metrics import mean_squared_error
+from scipy.stats import gamma
+
+
+def calibrate_global_error_variance(X: np.array, y: np.array, sigma2: float = None, nu: float = 3, lamb: float = None, q: float = 0.9, lm_calibrate: bool = True) -> tuple:
+    """Calibrates global error variance model by setting an initial value of sigma^2 (the parameter itself) and setting a value of lambda, part of the scale parameter in the 
+    ``sigma2 ~ IG(nu/2, (nu*lambda)/2)`` prior.
+
+    Parameters
+    ----------
+    X : :obj:`np.array`
+        Covariates to be used as split candidates for constructing trees.
+    y : :obj:`np.array`
+        Outcome to be used as target for constructing trees.
+    sigma2 : :obj:`float`, optional
+        Starting value of global variance parameter. Calibrated internally if not set here.
+    nu : :obj:`float`, optional
+        Shape parameter in the ``IG(nu, nu*lamb)`` global error variance model. Defaults to ``3``.
+    lamb : :obj:`float`, optional
+        Component of the scale parameter in the ``IG(nu, nu*lambda)`` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).
+    q : :obj:`float`, optional
+        Quantile used to calibrate ``lamb`` as in Sparapani et al (2021). Defaults to ``0.9``.
+    lm_calibrate : :obj:`bool`, optional
+        Whether or not to calibrate sigma2 based on a linear model of ``y`` given ``X``. If ``True``, uses the linear model calibration technique in Sparapani et al (2021), otherwise uses ``np.var(y)``. Defaults to ``True``.
+
+    Returns
+    -------
+    (sigma2, lamb) : :obj:`tuple` of :obj:`float`
+        Tuple containing an initial value of sigma^2 (global error variance) and lambda (part of scale parameter of global error variance model)
+    """
+    # Initialize sigma if no initial value is provided
+    var_y = np.var(y)
+    if sigma2 is None:
+        if lm_calibrate:
+            # Convert X and y to expected dimensions
+            if X.ndim == 2:
+                X_processed = X
+            elif X.ndim == 1:
+                X_processed = np.expand_dims(X, 1)
+            else:
+                raise ValueError("X must be a 1 or 2 dimensional numpy array")
+            n, p = X_processed.shape
+
+            if y.ndim == 2:
+                y_processed = np.squeeze(y)
+            elif y.ndim == 1:
+                y_processed = y
+            else:
+                raise ValueError("y must be a 1 or 2 dimensional numpy array")
+
+            # Fit a linear model of y ~ X
+            lm_calibrator = linear_model.LinearRegression()
+            lm_calibrator.fit(X_processed, y_processed)
+
+            # Compute MSE
+            y_hat_processed = lm_calibrator.predict(X_processed)
+            mse = mean_squared_error(y_processed, y_hat_processed)
+
+            # Check for overdetermination, revert to variance of y if model is overdetermined
+            eps = np.finfo("double").eps
+            if _is_model_overdetermined(lm_calibrator, n, mse, eps):
+                sigma2 = var_y
+                warnings.warn("Default calibration method for global error variance failed; covariate dimension exceeds number of samples. "
+                              "Initializing global error variance based on the variance of the standardized outcome.", UserWarning)
+            else:
+                sigma2 = mse
+                if _is_model_rank_deficient(lm_calibrator, p):
+                    warnings.warn("Default calibration method for global error variance detected rank deficiency in covariate matrix. "
+                                  "This should not impact the calibrated values, but may indicate the presence of duplicated covariates.", UserWarning)
+        else:
+            sigma2 = var_y
+
+    # Calibrate lamb if no initial value is provided
+    if lamb is None:
+        lamb = (sigma2*gamma.ppf(1-q,nu))/nu
+
+    return (sigma2, lamb)
+
+def _is_model_overdetermined(reg_model: linear_model.LinearRegression, n: int, mse: float, eps: float) -> bool:
+
+    if reg_model.rank_ == n:
+        return True
+    elif np.abs(mse) < eps:
+        return True
+    else:
+        return False
+
+def _is_model_rank_deficient(reg_model: linear_model.LinearRegression, p: int) -> bool:
+    if reg_model.rank_ < p:
+        return True
+    else:
+        return False
diff --git a/test/python/test_calibration.py b/test/python/test_calibration.py
new file mode 100644
index 00000000..99292cce
--- /dev/null
+++ b/test/python/test_calibration.py
@@ -0,0 +1,57 @@
+import numpy as np
+import pandas as pd
+from sklearn import linear_model
+from sklearn.metrics import mean_squared_error
+from scipy.stats import gamma
+from stochtree import CovariateTransformer
+from stochtree import calibrate_global_error_variance
+import pytest
+
+class TestCalibration:
+    def test_full_rank(self):
+        n = 100
+        p = 5
+        nu = 3
+        q = 0.9
+        X = np.random.uniform(size=(n,p))
+        y = 1 + X[:,0]*0.1 - X[:,1]*0.2 + np.random.normal(size=n)
+        reg_model = linear_model.LinearRegression()
+        reg_model.fit(X, y)
+        mse = mean_squared_error(y, reg_model.predict(X))
+        sigma2, lamb = calibrate_global_error_variance(X, y, None, nu, None, q, True)
+        assert sigma2 == pytest.approx(mse)
+        assert lamb == pytest.approx((mse*gamma.ppf(1-q,nu))/nu)
+
+    def test_rank_deficient(self):
+        n = 100
+        p = 5
+        nu = 3
+        q = 0.9
+        X = np.random.uniform(size=(n,p))
+        X[:,4] = X[:,2]
+        y = 1 + X[:,0]*0.1 - X[:,1]*0.2 + np.random.normal(size=n)
+        reg_model = linear_model.LinearRegression()
+        reg_model.fit(X, y)
+        mse = mean_squared_error(y, reg_model.predict(X))
+        if reg_model.rank_ < p:
+            with pytest.warns(UserWarning):
+                sigma2, lamb = calibrate_global_error_variance(X, y, None, nu, None, q, True)
+        else:
+            sigma2, lamb = calibrate_global_error_variance(X, y, None, nu, None, q, True)
+        assert sigma2 == pytest.approx(mse)
+        assert lamb == pytest.approx((mse*gamma.ppf(1-q,nu))/nu)
+
+    def test_overdetermined(self):
+        n = 100
+        p = 101
+        nu = 3
+        q = 0.9
+        X = np.random.uniform(size=(n,p))
+        y = 1 + X[:,0]*0.1 - X[:,1]*0.2 + np.random.normal(size=n)
+        reg_model = linear_model.LinearRegression()
+        reg_model.fit(X, y)
+        mse = mean_squared_error(y, reg_model.predict(X))
+        with pytest.warns(UserWarning):
+            sigma2, lamb = calibrate_global_error_variance(X, y, None, nu, None, q, True)
+        assert sigma2 == pytest.approx(np.var(y))
+        assert lamb == pytest.approx(np.var(y)*(gamma.ppf(1-q,nu))/nu)