4 changes: 2 additions & 2 deletions demo/notebooks/prototype_interface.ipynb
@@ -678,7 +678,7 @@
"outputs": [],
"source": [
"forest_pred_avg_gfr = np.squeeze(forest_preds_mu_gfr).mean(axis = 1, keepdims = True)\n",
"forest_pred_df_gfr = pd.DataFrame(np.concatenate((np.expand_dims(tau_X,1), forest_pred_avg_gfr), axis = 1), columns=[\"True mu\", \"Average estimated mu\"])\n",
"forest_pred_df_gfr = pd.DataFrame(np.concatenate((np.expand_dims(mu_X,1), forest_pred_avg_gfr), axis = 1), columns=[\"True mu\", \"Average estimated mu\"])\n",
"sns.scatterplot(data=forest_pred_df_gfr, x=\"True mu\", y=\"Average estimated mu\")\n",
"plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n",
"plt.show()"
@@ -734,7 +734,7 @@
"outputs": [],
"source": [
"forest_pred_avg_mcmc = np.squeeze(forest_preds_mu_mcmc).mean(axis = 1, keepdims = True)\n",
"forest_pred_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(tau_X,1), forest_pred_avg_mcmc), axis = 1), columns=[\"True mu\", \"Average estimated mu\"])\n",
"forest_pred_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(mu_X,1), forest_pred_avg_mcmc), axis = 1), columns=[\"True mu\", \"Average estimated mu\"])\n",
"sns.scatterplot(data=forest_pred_df_mcmc, x=\"True mu\", y=\"Average estimated mu\")\n",
"plt.axline((0, 0), slope=1, color=\"black\", linestyle=(0, (3,3)))\n",
"plt.show()"
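For context, a minimal sketch of what these notebook cells compare after the fix (assuming, as the plot labels suggest, that `mu_X` holds the true prognostic values and `tau_X` the true treatment effects at the training covariates): the averaged mu-forest predictions are now plotted against `mu_X`, so the data match the "True mu" axis label. The names below are illustrative stand-ins, not objects defined in the notebook at this point.

```python
import numpy as np
import pandas as pd

mu_X = np.random.normal(size=100)                                               # stand-in for the true prognostic values
forest_preds_mu = mu_X[:, None] + np.random.normal(scale=0.1, size=(100, 50))   # fake posterior draws
forest_pred_avg = forest_preds_mu.mean(axis=1, keepdims=True)                   # posterior mean per observation
forest_pred_df = pd.DataFrame(
    np.concatenate((np.expand_dims(mu_X, 1), forest_pred_avg), axis=1),
    columns=["True mu", "Average estimated mu"],
)
```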
3 changes: 2 additions & 1 deletion stochtree/__init__.py
@@ -1,5 +1,6 @@
from .bart import BARTModel
from .bcf import BCFModel
from .calibration import calibrate_global_error_variance
from .data import Dataset, Residual
from .forest import ForestContainer
from .preprocessing import CovariateTransformer
@@ -9,4 +10,4 @@

__all__ = ['BARTModel', 'BCFModel', 'Dataset', 'Residual', 'ForestContainer',
           'CovariateTransformer', 'RNG', 'ForestSampler', 'GlobalVarianceModel',
           'LeafVarianceModel', 'JSONSerializer', 'NotSampledError']
           'LeafVarianceModel', 'JSONSerializer', 'NotSampledError', 'calibrate_global_error_variance']
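Since `calibrate_global_error_variance` is now part of the package's public API, it can be imported from the top level. A quick sketch (assuming the package is installed) that checks the new export:

```python
import stochtree
from stochtree import calibrate_global_error_variance

# The helper is listed in the package's public API after this change
assert 'calibrate_global_error_variance' in stochtree.__all__
print(calibrate_global_error_variance.__doc__.splitlines()[0])
```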
17 changes: 8 additions & 9 deletions stochtree/bart.py
@@ -3,8 +3,10 @@
"""
import numpy as np
import pandas as pd
from scipy.linalg import lstsq
from sklearn import linear_model
from sklearn.metrics import mean_squared_error
from scipy.stats import gamma
from .calibration import calibrate_global_error_variance
from .data import Dataset, Residual
from .forest import ForestContainer
from .preprocessing import CovariateTransformer
@@ -173,14 +175,11 @@ def sample(self, X_train: np.array, y_train: np.array, basis_train: np.array = N
        self.y_std = np.squeeze(np.std(y_train))
        resid_train = (y_train-self.y_bar)/self.y_std

        # Calibrate priors for global sigma^2 and sigma_leaf
        if lamb is None:
            reg_basis = np.c_[np.ones(self.n_train),X_train_processed]
            reg_soln = lstsq(reg_basis, np.squeeze(resid_train))
            sigma2hat = reg_soln[1] / self.n_train
            quantile_cutoff = q
            lamb = (sigma2hat*gamma.ppf(1-quantile_cutoff,nu))/nu
            sigma2 = sigma2hat if sigma2 is None else sigma2
        # Calibrate priors for global sigma^2 and sigma_leaf (don't use regression initializer for warm-start or XBART)
        if num_gfr > 0:
            sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, False)
        else:
            sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, True)
        b_leaf = np.squeeze(np.var(resid_train)) / num_trees if b_leaf is None else b_leaf
        sigma_leaf = np.squeeze(np.var(resid_train)) / num_trees if sigma_leaf is None else sigma_leaf
        current_sigma2 = sigma2
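A rough sketch of what the two branches above amount to when the helper is called outside of `BARTModel.sample` (the data and hyperparameter values are illustrative, not taken from the library's defaults): with GFR / warm-start sampling (`num_gfr > 0`) the linear-model initializer is skipped and sigma^2 starts from the variance of the standardized outcome, while the pure-MCMC path initializes it from a linear-model MSE.

```python
import numpy as np
from stochtree import calibrate_global_error_variance

rng = np.random.default_rng(1)
X_train_processed = rng.uniform(size=(200, 4))
y_train = X_train_processed[:, 0] - 2.0 * X_train_processed[:, 1] + rng.normal(size=200)
resid_train = (y_train - y_train.mean()) / y_train.std()  # standardized outcome, as above

nu, q = 3, 0.9

# GFR / warm-start path (num_gfr > 0): skip the regression initializer
sigma2_gfr, lamb_gfr = calibrate_global_error_variance(
    X_train_processed, np.squeeze(resid_train), None, nu, None, q, False)

# Pure MCMC path: initialize sigma^2 from the linear-model fit of resid_train on X
sigma2_mcmc, lamb_mcmc = calibrate_global_error_variance(
    X_train_processed, np.squeeze(resid_train), None, nu, None, q, True)
```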
14 changes: 6 additions & 8 deletions stochtree/bcf.py
@@ -10,6 +10,7 @@
from scipy.linalg import lstsq
from scipy.stats import gamma
from .bart import BARTModel
from .calibration import calibrate_global_error_variance
from .data import Dataset, Residual
from .forest import ForestContainer
from .preprocessing import CovariateTransformer
@@ -510,14 +511,11 @@ def sample(self, X_train: Union[pd.DataFrame, np.array], Z_train: np.array, y_tr
        self.y_std = np.squeeze(np.std(y_train))
        resid_train = (y_train-self.y_bar)/self.y_std

        # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau
        if lamb is None:
            reg_basis = np.c_[np.ones(self.n_train),X_train_processed]
            reg_soln = lstsq(reg_basis, np.squeeze(resid_train))
            sigma2hat = reg_soln[1] / self.n_train
            quantile_cutoff = q
            lamb = (sigma2hat*gamma.ppf(1-quantile_cutoff,nu))/nu
            sigma2 = sigma2hat if sigma2 is None else sigma2
        # Calibrate priors for global sigma^2 and sigma_leaf_mu / sigma_leaf_tau (don't use regression initializer for warm-start or XBART)
        if num_gfr > 0:
            sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, False)
        else:
            sigma2, lamb = calibrate_global_error_variance(X_train_processed, np.squeeze(resid_train), sigma2, nu, lamb, q, True)
        b_leaf_mu = np.squeeze(np.var(resid_train)) / num_trees_mu if b_leaf_mu is None else b_leaf_mu
        b_leaf_tau = np.squeeze(np.var(resid_train)) / (2*num_trees_tau) if b_leaf_tau is None else b_leaf_tau
        sigma_leaf_mu = np.squeeze(np.var(resid_train)) / num_trees_mu if sigma_leaf_mu is None else sigma_leaf_mu
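The leaf-scale defaults that follow the calibration call in `BCFModel.sample` split the standardized residual variance across the prognostic and treatment-effect forests. A small sketch mirroring the lines above, with hypothetical tree counts chosen only for illustration:

```python
import numpy as np

resid_train = np.random.normal(size=500)   # stand-in for the standardized training outcome
num_trees_mu, num_trees_tau = 200, 50      # illustrative values, not asserted library defaults

# Default prior scales when the user does not supply them
b_leaf_mu = np.squeeze(np.var(resid_train)) / num_trees_mu
b_leaf_tau = np.squeeze(np.var(resid_train)) / (2 * num_trees_tau)
sigma_leaf_mu = np.squeeze(np.var(resid_train)) / num_trees_mu
```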
96 changes: 96 additions & 0 deletions stochtree/calibration.py
@@ -0,0 +1,96 @@
import warnings
import numpy as np
import pandas as pd
from sklearn import linear_model
from sklearn.metrics import mean_squared_error
from scipy.stats import gamma


def calibrate_global_error_variance(X: np.array, y: np.array, sigma2: float = None, nu: float = 3, lamb: float = None, q: float = 0.9, lm_calibrate: bool = True) -> tuple:
    """Calibrate the global error variance model by setting an initial value of sigma^2 (the parameter itself) and a value of lambda, which forms part of the scale parameter
    in the ``sigma2 ~ IG(nu/2, (nu*lambda)/2)`` prior.

    Parameters
    ----------
    X : :obj:`np.array`
        Covariates to be used as split candidates for constructing trees.
    y : :obj:`np.array`
        Outcome to be used as target for constructing trees.
    sigma2 : :obj:`float`, optional
        Starting value of the global variance parameter. Calibrated internally if not set here.
    nu : :obj:`float`, optional
        Shape parameter in the ``sigma2 ~ IG(nu/2, (nu*lambda)/2)`` global error variance prior. Defaults to ``3``.
    lamb : :obj:`float`, optional
        Component of the scale parameter in the ``sigma2 ~ IG(nu/2, (nu*lambda)/2)`` global error variance prior. If not specified, this is calibrated as in Sparapani et al (2021).
    q : :obj:`float`, optional
        Quantile used to calibrate ``lamb`` as in Sparapani et al (2021). Defaults to ``0.9``.
    lm_calibrate : :obj:`bool`, optional
        Whether or not to calibrate ``sigma2`` based on a linear model of ``y`` given ``X``. If ``True``, uses the linear model calibration technique in Sparapani et al (2021); otherwise uses ``np.var(y)``. Defaults to ``True``.

    Returns
    -------
    (sigma2, lamb) : :obj:`tuple` of :obj:`float`
        Tuple containing an initial value of sigma^2 (the global error variance) and lambda (part of the scale parameter of the global error variance prior).
    """
    # Initialize sigma if no initial value is provided
    var_y = np.var(y)
    if sigma2 is None:
        if lm_calibrate:
            # Convert X and y to expected dimensions
            if X.ndim == 2:
                X_processed = X
            elif X.ndim == 1:
                X_processed = np.expand_dims(X, 1)
            else:
                raise ValueError("X must be a 1 or 2 dimensional numpy array")
            n, p = X_processed.shape

            if y.ndim == 2:
                y_processed = np.squeeze(y)
            elif y.ndim == 1:
                y_processed = y
            else:
                raise ValueError("y must be a 1 or 2 dimensional numpy array")

            # Fit a linear model of y ~ X
            lm_calibrator = linear_model.LinearRegression()
            lm_calibrator.fit(X_processed, y_processed)

            # Compute MSE
            y_hat_processed = lm_calibrator.predict(X_processed)
            mse = mean_squared_error(y_processed, y_hat_processed)

            # Check for overdetermination, revert to variance of y if model is overdetermined
            eps = np.finfo("double").eps
            if _is_model_overdetermined(lm_calibrator, n, mse, eps):
                sigma2 = var_y
                warnings.warn("Default calibration method for global error variance failed; covariate dimension exceeds number of samples. "
                              "Initializing global error variance based on the variance of the standardized outcome.", UserWarning)
            else:
                sigma2 = mse
                if _is_model_rank_deficient(lm_calibrator, p):
                    warnings.warn("Default calibration method for global error variance detected rank deficiency in covariate matrix. "
                                  "This should not impact the calibrated values, but may indicate the presence of duplicated covariates.", UserWarning)
        else:
            sigma2 = var_y

    # Calibrate lamb if no initial value is provided
    if lamb is None:
        lamb = (sigma2*gamma.ppf(1-q,nu))/nu

    return (sigma2, lamb)

def _is_model_overdetermined(reg_model: linear_model.LinearRegression, n: int, mse: float, eps: float) -> bool:
    if reg_model.rank_ == n:
        return True
    elif np.abs(mse) < eps:
        return True
    else:
        return False

def _is_model_rank_deficient(reg_model: linear_model.LinearRegression, p: int) -> bool:
    if reg_model.rank_ < p:
        return True
    else:
        return False
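A short usage sketch of the fallback path documented above (synthetic data, assuming an environment with `stochtree` installed): when the covariate dimension exceeds the sample size, the linear-model initializer is abandoned, `sigma2` falls back to `np.var(y)`, and a `UserWarning` is raised.

```python
import warnings
import numpy as np
from stochtree import calibrate_global_error_variance

rng = np.random.default_rng(2024)
X = rng.uniform(size=(50, 80))            # p > n, so the linear model interpolates the data
y = X[:, 0] + rng.normal(size=50)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    sigma2, lamb = calibrate_global_error_variance(X, y, None, 3, None, 0.9, True)

print(np.isclose(sigma2, np.var(y)))      # True: variance-of-y fallback
print(len(caught) >= 1)                   # True: the overdetermination warning was raised
```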
57 changes: 57 additions & 0 deletions test/python/test_calibration.py
@@ -0,0 +1,57 @@
import numpy as np
import pandas as pd
from sklearn import linear_model
from sklearn.metrics import mean_squared_error
from scipy.stats import gamma
from stochtree import CovariateTransformer
from stochtree import calibrate_global_error_variance
import pytest

class TestCalibration:
    def test_full_rank(self):
        n = 100
        p = 5
        nu = 3
        q = 0.9
        X = np.random.uniform(size=(n,p))
        y = 1 + X[:,0]*0.1 - X[:,1]*0.2 + np.random.normal(size=n)
        reg_model = linear_model.LinearRegression()
        reg_model.fit(X, y)
        mse = mean_squared_error(y, reg_model.predict(X))
        sigma2, lamb = calibrate_global_error_variance(X, y, None, nu, None, q, True)
        assert sigma2 == pytest.approx(mse)
        assert lamb == pytest.approx((mse*gamma.ppf(1-q,nu))/nu)

    def test_rank_deficient(self):
        n = 100
        p = 5
        nu = 3
        q = 0.9
        X = np.random.uniform(size=(n,p))
        X[:,4] = X[:,2]
        y = 1 + X[:,0]*0.1 - X[:,1]*0.2 + np.random.normal(size=n)
        reg_model = linear_model.LinearRegression()
        reg_model.fit(X, y)
        mse = mean_squared_error(y, reg_model.predict(X))
        if reg_model.rank_ < p:
            with pytest.warns(UserWarning):
                sigma2, lamb = calibrate_global_error_variance(X, y, None, nu, None, q, True)
        else:
            sigma2, lamb = calibrate_global_error_variance(X, y, None, nu, None, q, True)
        assert sigma2 == pytest.approx(mse)
        assert lamb == pytest.approx((mse*gamma.ppf(1-q,nu))/nu)

    def test_overdetermined(self):
        n = 100
        p = 101
        nu = 3
        q = 0.9
        X = np.random.uniform(size=(n,p))
        y = 1 + X[:,0]*0.1 - X[:,1]*0.2 + np.random.normal(size=n)
        reg_model = linear_model.LinearRegression()
        reg_model.fit(X, y)
        mse = mean_squared_error(y, reg_model.predict(X))
        with pytest.warns(UserWarning):
            sigma2, lamb = calibrate_global_error_variance(X, y, None, nu, None, q, True)
        assert sigma2 == pytest.approx(np.var(y))
        assert lamb == pytest.approx(np.var(y)*(gamma.ppf(1-q,nu))/nu)
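One way to run just these tests from Python, a sketch assuming `pytest` is installed and the command is issued from the repository root:

```python
import pytest

# Run only the new calibration tests, quietly (exit code 0 means all tests passed)
exit_code = pytest.main(["-q", "test/python/test_calibration.py"])
print(exit_code)
```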