[MRG+2] Adding return_std options for models in linear_model/bayes.py (scikit-learn#7838)

* initial commit for return_std

* initial commit for return_std

* adding tests, examples, ARD predict_std

* adding tests, examples, ARD predict_std

* a smidge more documentation

* a smidge more documentation

* Missed a few PEP8 issues

* Changing predict_std to return_std scikit-learn#1

* Changing predict_std to return_std scikit-learn#2

* Changing predict_std to return_std scikit-learn#3

* Changing predict_std to return_std final

* adding better plots via polynomial regression

* trying to fix flake error

* fix to ARD plotting issue

* fixing some flakes

* Two blank lines part 1

* Two blank lines part 2

* More newlines!

* Even more newlines

* adding info to the doc string for the two plot files

* Rephrasing "polynomial" for Bayesian Ridge Regression

* Updating "polynomia" for ARD

* Adding more formal references

* Another asked-for improvement to doc string.

* Fixing flake8 errors

* Cleaning up the tests a smidge.

* A few more flakes

* requested fixes from Andy

* Mini bug fix

* Final pep8 fix

* pep8 fix round 2

* Fix beta_ to alpha_ in the comments
sergeyf authored and Sundrique committed Jun 14, 2017
1 parent 89b3fcf commit 9feedef
Showing 4 changed files with 203 additions and 9 deletions.
36 changes: 34 additions & 2 deletions examples/linear_model/plot_ard.py
@@ -15,6 +15,12 @@
The estimation of the model is done by iteratively maximizing the
marginal log-likelihood of the observations.
We also plot predictions and uncertainties for ARD
for one-dimensional regression using polynomial feature expansion.
Note that the uncertainty starts going up on the right side of the plot
because these test samples are outside the range of the training samples.
"""
print(__doc__)

@@ -54,8 +60,8 @@
ols.fit(X, y)

###############################################################################
-# Plot the true weights, the estimated weights and the histogram of the
-# weights
+# Plot the true weights, the estimated weights, the histogram of the
+# weights, and predictions with standard deviations
plt.figure(figsize=(6, 5))
plt.title("Weights of the model")
plt.plot(clf.coef_, color='darkblue', linestyle='-', linewidth=2,
@@ -81,4 +87,30 @@
plt.plot(clf.scores_, color='navy', linewidth=2)
plt.ylabel("Score")
plt.xlabel("Iterations")


# Plotting some predictions for polynomial regression
def f(x, noise_amount):
    y = np.sqrt(x) * np.sin(x)
    noise = np.random.normal(0, 1, len(x))
    return y + noise_amount * noise


degree = 10
X = np.linspace(0, 10, 100)
y = f(X, noise_amount=1)
clf_poly = ARDRegression(threshold_lambda=1e5)
clf_poly.fit(np.vander(X, degree), y)

X_plot = np.linspace(0, 11, 25)
y_plot = f(X_plot, noise_amount=0)
y_mean, y_std = clf_poly.predict(np.vander(X_plot, degree), return_std=True)
plt.figure(figsize=(6, 5))
plt.errorbar(X_plot, y_mean, y_std, color='navy',
             label="Polynomial ARD", linewidth=2)
plt.plot(X_plot, y_plot, color='gold', linewidth=2,
         label="Ground Truth")
plt.ylabel("Output y")
plt.xlabel("Feature X")
plt.legend(loc="lower left")
plt.show()
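For quick reference, the new API exercised by this example reduces to a few lines. A minimal standalone sketch (assuming a scikit-learn build with this patch applied; variable names are illustrative):

import numpy as np
from sklearn.linear_model import ARDRegression

rng = np.random.RandomState(0)
x = np.linspace(0, 10, 100)
y = np.sqrt(x) * np.sin(x) + rng.normal(0, 1, 100)
# np.vander(x, 10) builds 10 polynomial features (powers 0 through 9)
clf = ARDRegression(threshold_lambda=1e5)
clf.fit(np.vander(x, 10), y)
# With return_std=True, predict returns both the mean and the standard
# deviation of the predictive distribution at each query point
y_mean, y_std = clf.predict(np.vander(x, 10), return_std=True)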
35 changes: 34 additions & 1 deletion examples/linear_model/plot_bayesian_ridge.py
@@ -15,6 +15,12 @@
The estimation of the model is done by iteratively maximizing the
marginal log-likelihood of the observations.
We also plot predictions and uncertainties for Bayesian Ridge Regression
for one-dimensional regression using polynomial feature expansion.
Note that the uncertainty starts going up on the right side of the plot
because these test samples are outside the range of the training samples.
"""
print(__doc__)

@@ -51,7 +57,8 @@
ols.fit(X, y)

###############################################################################
-# Plot true weights, estimated weights and histogram of the weights
+# Plot true weights, estimated weights, histogram of the weights, and
+# predictions with standard deviations
lw = 2
plt.figure(figsize=(6, 5))
plt.title("Weights of the model")
@@ -77,4 +84,30 @@
plt.plot(clf.scores_, color='navy', linewidth=lw)
plt.ylabel("Score")
plt.xlabel("Iterations")


# Plotting some predictions for polynomial regression
def f(x, noise_amount):
    y = np.sqrt(x) * np.sin(x)
    noise = np.random.normal(0, 1, len(x))
    return y + noise_amount * noise


degree = 10
X = np.linspace(0, 10, 100)
y = f(X, noise_amount=0.1)
clf_poly = BayesianRidge()
clf_poly.fit(np.vander(X, degree), y)

X_plot = np.linspace(0, 11, 25)
y_plot = f(X_plot, noise_amount=0)
y_mean, y_std = clf_poly.predict(np.vander(X_plot, degree), return_std=True)
plt.figure(figsize=(6, 5))
plt.errorbar(X_plot, y_mean, y_std, color='navy',
             label="Polynomial Bayesian Ridge Regression", linewidth=lw)
plt.plot(X_plot, y_plot, color='gold', linewidth=lw,
         label="Ground Truth")
plt.ylabel("Output y")
plt.xlabel("Feature X")
plt.legend(loc="lower left")
plt.show()
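The Bayesian Ridge example follows the same call pattern. A minimal sketch of what feeds the error bars (again assuming this patch; noise_std is an illustrative name, not part of the API):

import numpy as np
from sklearn.linear_model import BayesianRidge

rng = np.random.RandomState(42)
X_train = rng.rand(50, 3)
y_train = np.dot(X_train, np.array([1.0, 2.0, 0.0])) + 0.1 * rng.randn(50)
clf = BayesianRidge().fit(X_train, y_train)
y_mean, y_std = clf.predict(X_train, return_std=True)
# y_std**2 combines the learned noise variance (1 / clf.alpha_) with the
# per-sample weight-uncertainty term x^T sigma_ x
noise_std = np.sqrt(1.0 / clf.alpha_)  # should come out near 0.1 here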
109 changes: 103 additions & 6 deletions sklearn/linear_model/bayes.py
@@ -91,6 +91,9 @@ class BayesianRidge(LinearModel, RegressorMixin):
lambda_ : float
estimated precision of the weights.
sigma_ : array, shape = (n_features, n_features)
estimated variance-covariance matrix of the weights
scores_ : float
if computed, value of the objective function (to be maximized)
@@ -109,6 +112,16 @@ class BayesianRidge(LinearModel, RegressorMixin):
Notes
-----
See examples/linear_model/plot_bayesian_ridge.py for an example.
References
----------
D. J. C. MacKay, Bayesian Interpolation, Computation and Neural Systems,
Vol. 4, No. 3, 1992.
R. Salakhutdinov, Lecture notes on Statistical Machine Learning,
http://www.utstat.toronto.edu/~rsalakhu/sta4273/notes/Lecture2.pdf#page=15
Their beta is our self.alpha_
Their alpha is our self.lambda_
"""

def __init__(self, n_iter=300, tol=1.e-3, alpha_1=1.e-6, alpha_2=1.e-6,
@@ -142,8 +155,10 @@ def fit(self, X, y):
self : returns an instance of self.
"""
X, y = check_X_y(X, y, dtype=np.float64, y_numeric=True)
-X, y, X_offset, y_offset, X_scale = self._preprocess_data(
+X, y, X_offset_, y_offset_, X_scale_ = self._preprocess_data(
X, y, self.fit_intercept, self.normalize, self.copy_X)
self.X_offset_ = X_offset_
self.X_scale_ = X_scale_
n_samples, n_features = X.shape

# Initialization of the values of the parameters
@@ -171,7 +186,8 @@ def fit(self, X, y):
# coef_ = sigma_^-1 * XT * y
if n_samples > n_features:
coef_ = np.dot(Vh.T,
-Vh / (eigen_vals_ + lambda_ / alpha_)[:, None])
+Vh / (eigen_vals_ +
+     lambda_ / alpha_)[:, np.newaxis])
coef_ = np.dot(coef_, XT_y)
if self.compute_score:
logdet_sigma_ = - np.sum(
@@ -216,10 +232,45 @@ def fit(self, X, y):
self.alpha_ = alpha_
self.lambda_ = lambda_
self.coef_ = coef_
sigma_ = np.dot(Vh.T,
Vh / (eigen_vals_ + lambda_ / alpha_)[:, np.newaxis])
self.sigma_ = (1. / alpha_) * sigma_

-self._set_intercept(X_offset, y_offset, X_scale)
+self._set_intercept(X_offset_, y_offset_, X_scale_)
return self

    def predict(self, X, return_std=False):
        """Predict using the linear model.

        In addition to the mean of the predictive distribution, its
        standard deviation can also be returned.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape = (n_samples, n_features)
            Samples.

        return_std : boolean, optional
            Whether to return the standard deviation of the posterior
            prediction.

        Returns
        -------
        y_mean : array, shape = (n_samples,)
            Mean of predictive distribution of query points.

        y_std : array, shape = (n_samples,)
            Standard deviation of predictive distribution of query points.
        """
        y_mean = self._decision_function(X)
        if return_std is False:
            return y_mean
        else:
            if self.normalize:
                X = (X - self.X_offset_) / self.X_scale_
            # predictive variance = noise variance (1 / alpha_) plus the
            # per-sample weight-uncertainty term x^T sigma_ x
            sigmas_squared_data = (np.dot(X, self.sigma_) * X).sum(axis=1)
            y_std = np.sqrt(sigmas_squared_data + (1. / self.alpha_))
            return y_mean, y_std


###############################################################################
# ARD (Automatic Relevance Determination) regression
@@ -323,6 +374,19 @@ class ARDRegression(LinearModel, RegressorMixin):
Notes
--------
See examples/linear_model/plot_ard.py for an example.
References
----------
D. J. C. MacKay, Bayesian nonlinear modeling for the prediction
competition, ASHRAE Transactions, 1994.
R. Salakhutdinov, Lecture notes on Statistical Machine Learning,
http://www.utstat.toronto.edu/~rsalakhu/sta4273/notes/Lecture2.pdf#page=15
Their beta is our self.alpha_
Their alpha is our self.lambda_
ARD is a little different from the slides: only dimensions/features for
which self.lambda_ < self.threshold_lambda are kept, and the rest are
discarded.
"""

def __init__(self, n_iter=300, tol=1.e-3, alpha_1=1.e-6, alpha_2=1.e-6,
@@ -365,7 +429,7 @@ def fit(self, X, y):
n_samples, n_features = X.shape
coef_ = np.zeros(n_features)

-X, y, X_offset, y_offset, X_scale = self._preprocess_data(
+X, y, X_offset_, y_offset_, X_scale_ = self._preprocess_data(
X, y, self.fit_intercept, self.normalize, self.copy_X)

# Launch the convergence loop
@@ -417,7 +481,7 @@ def fit(self, X, y):
s = (lambda_1 * np.log(lambda_) - lambda_2 * lambda_).sum()
s += alpha_1 * log(alpha_) - alpha_2 * alpha_
s += 0.5 * (fast_logdet(sigma_) + n_samples * log(alpha_) +
np.sum(np.log(lambda_)))
s -= 0.5 * (alpha_ * rmse_ + (lambda_ * coef_ ** 2).sum())
self.scores_.append(s)

@@ -432,5 +496,38 @@ def fit(self, X, y):
self.alpha_ = alpha_
self.sigma_ = sigma_
self.lambda_ = lambda_
-self._set_intercept(X_offset, y_offset, X_scale)
+self._set_intercept(X_offset_, y_offset_, X_scale_)
return self

    def predict(self, X, return_std=False):
        """Predict using the linear model.

        In addition to the mean of the predictive distribution, its
        standard deviation can also be returned.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape = (n_samples, n_features)
            Samples.

        return_std : boolean, optional
            Whether to return the standard deviation of the posterior
            prediction.

        Returns
        -------
        y_mean : array, shape = (n_samples,)
            Mean of predictive distribution of query points.

        y_std : array, shape = (n_samples,)
            Standard deviation of predictive distribution of query points.
        """
        y_mean = self._decision_function(X)
        if return_std is False:
            return y_mean
        else:
            if self.normalize:
                X = (X - self.X_offset_) / self.X_scale_
            # keep only the features retained by ARD pruning before
            # using sigma_, which is computed on the kept features only
            X = X[:, self.lambda_ < self.threshold_lambda]
            sigmas_squared_data = (np.dot(X, self.sigma_) * X).sum(axis=1)
            y_std = np.sqrt(sigmas_squared_data + (1. / self.alpha_))
            return y_mean, y_std
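Both predict implementations compute the same quantity: the predictive standard deviation is sqrt(x^T sigma_ x + 1 / alpha_), i.e. weight uncertainty plus the learned noise variance, with ARDRegression first dropping the pruned features. One caveat visible in the diff: BayesianRidge.fit stores self.X_offset_ and self.X_scale_, but ARDRegression.fit does not, so ARD's return_std path with normalize=True would fail; the default normalize=False is unaffected. A quick numerical check of the formula against predict (a sketch, not part of the patch):

import numpy as np
from sklearn.linear_model import BayesianRidge

rng = np.random.RandomState(0)
X = rng.rand(30, 4)
y = np.dot(X, np.ones(4)) + 0.05 * rng.randn(30)
clf = BayesianRidge().fit(X, y)
_, y_std = clf.predict(X, return_std=True)
# recompute the predictive standard deviation by hand
manual = np.sqrt((np.dot(X, clf.sigma_) * X).sum(axis=1) + 1.0 / clf.alpha_)
assert np.allclose(y_std, manual)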
32 changes: 32 additions & 0 deletions sklearn/linear_model/tests/test_bayes.py
@@ -56,3 +56,35 @@ def test_toy_ard_object():
# Check that the model could approximately learn the identity function
test = [[1], [3], [4]]
assert_array_almost_equal(clf.predict(test), [1, 3, 4], 2)


def test_return_std():
    # Test return_std option for both Bayesian regressors
    def f(X):
        return np.dot(X, w) + b

    def f_noise(X, noise_mult):
        return f(X) + np.random.randn(X.shape[0]) * noise_mult

    d = 5
    n_train = 50
    n_test = 10

    w = np.array([1.0, 0.0, 1.0, -1.0, 0.0])
    b = 1.0

    X = np.random.random((n_train, d))
    X_test = np.random.random((n_test, d))

    for decimal, noise_mult in enumerate([1, 0.1, 0.01]):
        y = f_noise(X, noise_mult)

        m1 = BayesianRidge()
        m1.fit(X, y)
        y_mean1, y_std1 = m1.predict(X_test, return_std=True)
        assert_array_almost_equal(y_std1, noise_mult, decimal=decimal)

        m2 = ARDRegression()
        m2.fit(X, y)
        y_mean2, y_std2 = m2.predict(X_test, return_std=True)
        assert_array_almost_equal(y_std2, noise_mult, decimal=decimal)
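Why asserting y_std ≈ noise_mult is reasonable: with n_train = 50 the weight-uncertainty term x^T sigma_ x becomes small, so y_std is dominated by sqrt(1 / alpha_), the model's estimate of the noise standard deviation. A standalone illustration (a sketch; the seed and sizes are arbitrary):

import numpy as np
from sklearn.linear_model import BayesianRidge

rng = np.random.RandomState(1)
X = rng.rand(50, 5)
y = np.dot(X, np.array([1.0, 0.0, 1.0, -1.0, 0.0])) + 1.0 + 0.1 * rng.randn(50)
clf = BayesianRidge().fit(X, y)
print(np.sqrt(1.0 / clf.alpha_))  # close to the true noise level, 0.1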
