Commit

Do away with intercept helper functions
MechCoder committed Jul 22, 2014
1 parent b088c2a commit 137d40b
Showing 2 changed files with 69 additions and 117 deletions.
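
For context: the commit folds the separate *_intercept helpers into the plain loss functions, which now look at the size of w — when w has one more entry than X has columns, the trailing entry is treated as the intercept. Below is a minimal sketch of that dispatch pattern, assuming dense NumPy arrays and labels y in {-1, +1}; the name toy_loss_and_grad is invented for illustration and is not part of the diff.

import numpy as np
from scipy.special import expit


def toy_loss_and_grad(w, X, y, alpha):
    # Hypothetical stand-in for the post-commit _logistic_loss_and_grad:
    # a single function serves both the intercept and no-intercept cases.
    n_features = X.shape[1]
    c = 0.
    fit_intercept = (w.size == n_features + 1)
    if fit_intercept:
        c = w[-1]                     # trailing entry is the intercept
        w = w[:-1]

    yz = y * (np.dot(X, w) + c)
    # log(1 + exp(-yz)) is -log_logistic(yz); logaddexp keeps it stable
    loss = np.sum(np.logaddexp(0, -yz)) + .5 * alpha * np.dot(w, w)

    z0 = (expit(yz) - 1) * y
    grad = np.empty(n_features + 1 if fit_intercept else n_features)
    grad[:n_features] = np.dot(X.T, z0) + alpha * w
    if fit_intercept:
        grad[-1] = z0.sum()           # derivative w.r.t. the intercept
    return loss, grad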
175 changes: 64 additions & 111 deletions sklearn/linear_model/logistic.py
@@ -20,6 +20,7 @@
from ..preprocessing import LabelEncoder, LabelBinarizer
from ..svm.base import BaseLibLinear
from ..utils import as_float_array, check_arrays
from ..utils.extmath import log_logistic, safe_sparse_dot
from ..externals.joblib import Parallel, delayed
from ..cross_validation import check_cv
from ..utils.optimize import newton_cg
@@ -30,140 +31,101 @@
# .. some helper functions for logistic_regression_path ..
def _logistic_loss_and_grad(w, X, y, alpha):
# the logistic loss and its gradient
z = X.dot(w)
fit_intercept = False
c = 0
_, n_features = X.shape
grad = np.empty_like(w)

# the fit_intercept case
if w.size == n_features + 1:
fit_intercept = True
c = w[-1]
w = w[:-1]

z = safe_sparse_dot(X, w)
z += c
yz = y * z
out = np.empty_like(yz)
idx = yz > 0
out[idx] = np.log(1 + np.exp(-yz[idx]))
out[~idx] = (-yz[~idx] + np.log(1 + np.exp(yz[~idx])))
out = out.sum() + .5 * alpha * w.dot(w)

# Logistic loss is the negative of the log of the logistic function.
out = -np.sum(log_logistic(yz)) + .5 * alpha * np.dot(w, w)

z = special.expit(yz)
z0 = (z - 1) * y
grad = X.T.dot(z0) + alpha * w

grad[:n_features] = safe_sparse_dot(X.T, z0) + alpha * w
if fit_intercept:
grad[-1] = z0.sum()
return out, grad


def _logistic_loss(w, X, y, alpha):

# For the fit_intercept case.
c = 0
if w.size == X.shape[1] + 1:
c = w[-1]
w = w[:-1]

# the logistic loss and
z = X.dot(w)
z = safe_sparse_dot(X, w)
z += c
yz = y * z
out = np.empty_like(yz)
idx = yz > 0
out[idx] = np.log(1 + np.exp(-yz[idx]))
out[~idx] = (-yz[~idx] + np.log(1 + np.exp(yz[~idx])))
out = out.sum() + .5 * alpha * w.dot(w)
#print 'Loss %r' % out

# Logistic loss is the negative of the log of the logistic function.
out = -np.sum(log_logistic(yz)) + .5 * alpha * np.dot(w, w)
return out


def _logistic_loss_grad_hess(w, X, y, alpha):
# the logistic loss, its gradient, and the matvec application of the
# Hessian
z = X.dot(w)
yz = y * z
out = np.empty_like(yz)
idx = yz > 0
out[idx] = np.log(1 + np.exp(-yz[idx]))
out[~idx] = (-yz[~idx] + np.log(1 + np.exp(yz[~idx])))
out = out.sum() + .5 * alpha * w.dot(w)

z = special.expit(yz)
z0 = (z - 1) * y
grad = X.T.dot(z0) + alpha * w

# The mat-vec product of the Hessian
d = z * (1 - z)
d = np.sqrt(d, out=d)
if sparse.issparse(X):
dX = sparse.dia_matrix((d, 0), shape=(d.size, d.size)).dot(X)
else:
# Precompute as much as possible
dX = d[:, np.newaxis] * X

def Hs(s):
ret = dX.T.dot(dX.dot(s))
ret += alpha * s
return ret
#print 'Loss/grad/hess %r, %r' % (out, grad.dot(grad))
return out, grad, Hs

n_samples, n_features = X.shape
fit_intercept = False
c = 0
grad = np.empty_like(w)

def _logistic_loss_and_grad_intercept(w_c, X, y, alpha):
w = w_c[:-1]
c = w_c[-1]
if w.size == n_features + 1:
fit_intercept = True
c = w[-1]
w = w[:-1]

z = X.dot(w)
z = safe_sparse_dot(X, w)
z += c
yz = y * z
out = np.empty_like(yz)
idx = yz > 0
out[idx] = np.log(1 + np.exp(-yz[idx]))
out[~idx] = (-yz[~idx] + np.log(1 + np.exp(yz[~idx])))
out = out.sum() + .5 * alpha * w.dot(w)

z = special.expit(yz)
z0 = (z - 1) * y
grad = np.empty_like(w_c)
grad[:-1] = X.T.dot(z0) + alpha * w
grad[-1] = z0.sum()
return out, grad


def _logistic_loss_intercept(w_c, X, y, alpha):
w = w_c[:-1]
c = w_c[-1]

z = X.dot(w)
z += c
yz = y * z
out = np.empty_like(yz)
idx = yz > 0
out[idx] = np.log(1 + np.exp(-yz[idx]))
out[~idx] = (-yz[~idx] + np.log(1 + np.exp(yz[~idx])))
out = out.sum() + .5 * alpha * w.dot(w)

#print 'Loss %r' % out
return out


def _logistic_loss_grad_hess_intercept(w_c, X, y, alpha):
w = w_c[:-1]
c = w_c[-1]

z = X.dot(w)
z += c
yz = y * z
out = np.empty_like(yz)
idx = yz > 0
out[idx] = np.log(1 + np.exp(-yz[idx]))
out[~idx] = (-yz[~idx] + np.log(1 + np.exp(yz[~idx])))
out = out.sum() + .5 * alpha * w.dot(w)
# Logistic loss is the negative of the log of the logistic function.
out = -np.sum(log_logistic(yz)) + .5 * alpha * np.dot(w, w)

z = special.expit(yz)
z0 = (z - 1) * y
grad = np.empty_like(w_c)
grad[:-1] = X.T.dot(z0) + alpha * w
z0_sum = z0.sum()
grad[-1] = z0_sum
grad[:n_features] = safe_sparse_dot(X.T, z0) + alpha * w
if fit_intercept:
z0_sum = np.sum(z0)
grad[-1] = np.sum(z0)

# The mat-vec product of the Hessian
d = z * (1 - z)
d = np.sqrt(d, out=d)
if sparse.issparse(X):
dX = sparse.dia_matrix((d, 0), shape=(d.size, d.size)).dot(X)
dX = safe_sparse_dot(sparse.dia_matrix((d, 0),
shape=(n_samples, n_samples)), X)
else:
# Precompute as much as possible
dX = d[:, np.newaxis] * X

def Hs(s):
ret = np.empty_like(s)
ret[:-1] = dX.T.dot(dX.dot(s[:-1]))
ret[:-1] += alpha * s[:-1]
# XXX: I am not sure that this last line of the Hessian is right
# Without the intercept the Hessian is right, though
ret[-1] = z0_sum * s[-1]
ret[:n_features] = dX.T.dot(dX.dot(s[:n_features]))
ret[:n_features] += alpha * s[:n_features]
if fit_intercept:
# XXX: Is this right?
ret[-1] = z0_sum * s[-1]
return ret

#print 'Loss/grad/hess %r, %r' % (out, grad.dot(grad))
return out, grad, Hs


def logistic_regression_path(X, y, Cs=10, fit_intercept=True,
max_iter=100, gtol=1e-4, verbose=0,
solver='liblinear', callback=None,
@@ -239,10 +201,8 @@ def logistic_regression_path(X, y, Cs=10, fit_intercept=True,
'implemented for the binary class case')
if fit_intercept:
w0 = np.zeros(X.shape[1] + 1)
func = _logistic_loss_and_grad_intercept
else:
w0 = np.zeros(X.shape[1])
func = _logistic_loss_and_grad

if coef is not None:
# it must work both giving the bias term and not
@@ -255,6 +215,7 @@ def logistic_regression_path(X, y, Cs=10, fit_intercept=True,
if callback is not None:
callback(w0, X, y, 1. / C)
if solver == 'lbfgs':
func = _logistic_loss_and_grad
try:
out = optimize.fmin_l_bfgs_b(
func, w0, fprime=None,
@@ -268,17 +229,9 @@ def logistic_regression_path(X, y, Cs=10, fit_intercept=True,
iprint=verbose > 0, pgtol=gtol)
w0 = out[0]
elif solver == 'newton-cg':
if fit_intercept:
func_grad_hess = _logistic_loss_grad_hess_intercept
func = _logistic_loss_intercept
grad = lambda x, *args: _logistic_loss_and_grad_intercept(x, *args)[1]
else:
func_grad_hess = _logistic_loss_grad_hess
func = _logistic_loss
grad = lambda x, *args: _logistic_loss_and_grad(x, *args)[1]

w0 = newton_cg(func_grad_hess, func, grad, w0, args=(X, y, 1./C),
maxiter=max_iter)
grad = lambda x, *args: _logistic_loss_and_grad(x, *args)[1]
w0 = newton_cg(_logistic_loss_grad_hess, _logistic_loss, grad,
w0, args=(X, y, 1./C), maxiter=max_iter)
elif solver == 'liblinear':
lr = LogisticRegression(C=C, fit_intercept=fit_intercept, tol=gtol)
lr.fit(X, y)
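
With the helpers gone, the 'lbfgs' and 'newton-cg' branches above can pass the same objective whether or not w0 carries an intercept slot. A usage sketch of the L-BFGS call, reusing the toy_loss_and_grad sketch from the top of this page on made-up data:

import numpy as np
from scipy import optimize

# Mirrors the 'lbfgs' branch above: fmin_l_bfgs_b consumes a function that
# returns (loss, grad); the extra entry in w0 switches the intercept on.
rng = np.random.RandomState(0)
X = rng.randn(20, 3)
y = np.sign(rng.randn(20))

w0 = np.zeros(X.shape[1] + 1)        # extra slot -> intercept is fitted
w_opt, loss, info = optimize.fmin_l_bfgs_b(toy_loss_and_grad, w0,
                                            args=(X, y, 1.), pgtol=1e-4)
coef, intercept = w_opt[:-1], w_opt[-1]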
11 changes: 5 additions & 6 deletions sklearn/linear_model/tests/test_logistic.py
@@ -12,8 +12,7 @@

from sklearn.linear_model.logistic import (LogisticRegression,
logistic_regression_path, LogisticRegressionCV,
_logistic_loss_and_grad, _logistic_loss_and_grad_intercept,
_logistic_loss_grad_hess, _logistic_loss_grad_hess_intercept)
_logistic_loss_and_grad, _logistic_loss_grad_hess)
from sklearn.datasets import load_iris, make_classification

X = [[-1, 0], [0, 1], [1, 1]]
@@ -203,12 +202,12 @@ def test__logistic_loss_and_grad():

# Second check that our intercept implementation is good
w = np.zeros(n_features + 1)
loss_interp, grad_interp = _logistic_loss_and_grad_intercept(w,
loss_interp, grad_interp = _logistic_loss_and_grad(w,
X, y, alpha=1.)
assert_array_almost_equal(loss, loss_interp)

approx_grad = optimize.approx_fprime(w,
lambda w: _logistic_loss_and_grad_intercept(
lambda w: _logistic_loss_and_grad(
w, X, y, alpha=1.)[0], 1e-3)
assert_array_almost_equal(grad_interp, approx_grad, decimal=2)

@@ -258,10 +257,10 @@ def test__logistic_loss_grad_hess():

# Second check that our intercept implementation is good
w = np.zeros(n_features + 1)
loss_interp, grad_interp = _logistic_loss_and_grad_intercept(w,
loss_interp, grad_interp = _logistic_loss_and_grad(w,
X, y, alpha=1.)
loss_interp_2, grad_interp_2, hess = \
_logistic_loss_grad_hess_intercept(w, X, y, alpha=1.)
_logistic_loss_grad_hess(w, X, y, alpha=1.)
assert_array_almost_equal(loss_interp, loss_interp_2)
assert_array_almost_equal(grad_interp, grad_interp_2)

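
The updated tests exercise the combined functions directly, checking the analytic gradient against a finite-difference one via optimize.approx_fprime. A sketch of the same check outside the test suite, again assuming the toy_loss_and_grad sketch from the top of this page:

import numpy as np
from scipy import optimize

# Finite-difference check in the spirit of the updated test.
rng = np.random.RandomState(0)
X = rng.randn(20, 3)
y = np.sign(rng.randn(20))
w = np.zeros(X.shape[1] + 1)         # include the intercept slot

_, grad = toy_loss_and_grad(w, X, y, alpha=1.)
approx_grad = optimize.approx_fprime(
    w, lambda w: toy_loss_and_grad(w, X, y, alpha=1.)[0], 1e-3)
print(np.max(np.abs(grad - approx_grad)))   # should be small (~1e-2 or less)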
