[MRG+2] LogisticRegression convert to float64 (newton-cg) (scikit-learn#8835)

* Add a test to ensure not changing the input's data type

Test that np.float32 input data is not cast to np.float64 when using LR + newton-cg

* [WIP] Force X to remain float32. (self.coef_ remains float64 even if X is not)

* [WIP] ensure self.coef_ same type as X

* keep the np.float32 when multi_class='multinomial'

* Avoid hardcoded type for multinomial

* pass flake8

* Ensure that the results in 32 bits are the same as in 64 bits

* Address Gael's comments for multi_class=='ovr'

* Add multi_class=='multinomial' to test

* Add support for multi_class=='multinomial'

* prefer float64 to float32

* Force X and y to have the same type

* Revert "Add support for multi_class=='multinominal'"

This reverts commit 4ac33e8.

* revert more stuff

* clean up some commented code

* allow class_weight to take advantage of float32

* Add a test where X.dtype differs from y.dtype

* Address @raghavrv comments

* address the rest of @raghavrv's comments

* Revert class_weight

* Avoid copying if dtype matches

* Address Alex's comment on the cast inside _multinomial_loss_grad

* address Alex's comment

* add sparsity test

* Address Tom's comment: check that we keep float64 as well
massich authored and AishwaryaRK committed Aug 29, 2017
1 parent 3e72a6e commit 5ab17a4
Showing 2 changed files with 44 additions and 8 deletions.
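
In practical terms, fitting LogisticRegression with solver='newton-cg' on np.float32 data no longer upcasts it to np.float64, and the learned coefficients come back in float32 as well. A minimal sketch of the behaviour the new test exercises (the toy data here is illustrative, and a scikit-learn build containing this commit is assumed):

import numpy as np
from sklearn.linear_model import LogisticRegression

# illustrative toy data; what matters is only the float32 dtype
rng = np.random.RandomState(0)
X_32 = rng.rand(20, 3).astype(np.float32)
y = np.array([0, 1] * 10)

clf = LogisticRegression(solver='newton-cg')
clf.fit(X_32, y)
print(clf.coef_.dtype)  # float32 with this change; float64 before it
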
22 changes: 14 additions & 8 deletions sklearn/linear_model/logistic.py
@@ -338,7 +338,8 @@ def _multinomial_loss_grad(w, X, Y, alpha, sample_weight):
     n_classes = Y.shape[1]
     n_features = X.shape[1]
     fit_intercept = (w.size == n_classes * (n_features + 1))
-    grad = np.zeros((n_classes, n_features + bool(fit_intercept)))
+    grad = np.zeros((n_classes, n_features + bool(fit_intercept)),
+                    dtype=X.dtype)
     loss, p, w = _multinomial_loss(w, X, Y, alpha, sample_weight)
     sample_weight = sample_weight[:, np.newaxis]
     diff = sample_weight * (p - Y)
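
Why the gradient buffer has to follow X.dtype: writing a float32 product into a float64-allocated array silently upcasts it, and the float64 gradient then propagates through the rest of the newton-cg iteration. A small illustrative NumPy sketch (the variable names are ad hoc, not the ones in logistic.py):

import numpy as np

X = np.ones((4, 3), dtype=np.float32)
diff = np.ones((4, 2), dtype=np.float32)

grad_old = np.zeros((2, 3))                 # defaults to float64 (old code)
grad_new = np.zeros((2, 3), dtype=X.dtype)  # follows X (this commit)

update = diff.T.dot(X)             # float32: NumPy keeps the common dtype
grad_old[:, :] = update            # assignment into a float64 buffer upcasts
grad_new[:, :] = update            # stays float32 end to end
print(update.dtype, grad_old.dtype, grad_new.dtype)  # float32 float64 float32
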
@@ -609,10 +610,10 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
     # and check length
     # Otherwise set them to 1 for all examples
     if sample_weight is not None:
-        sample_weight = np.array(sample_weight, dtype=np.float64, order='C')
+        sample_weight = np.array(sample_weight, dtype=X.dtype, order='C')
         check_consistent_length(y, sample_weight)
     else:
-        sample_weight = np.ones(X.shape[0])
+        sample_weight = np.ones(X.shape[0], dtype=X.dtype)
 
     # If class_weights is a dict (provided by the user), the weights
     # are assigned to the original labels. If it is "balanced", then
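
The same reasoning applies to sample_weight: multiplying float32 data by a float64 weight vector upcasts the product, so the weights are now created in X's dtype. A tiny illustration (ad hoc names again):

import numpy as np

X = np.ones((5, 2), dtype=np.float32)
sw_64 = np.ones(X.shape[0])                  # old: float64 by default
sw_32 = np.ones(X.shape[0], dtype=X.dtype)   # new: follows X

print((sw_64[:, np.newaxis] * X).dtype)  # float64: the upcast the change avoids
print((sw_32[:, np.newaxis] * X).dtype)  # float32
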
@@ -625,10 +626,10 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
     # For doing a ovr, we need to mask the labels first. for the
     # multinomial case this is not necessary.
     if multi_class == 'ovr':
-        w0 = np.zeros(n_features + int(fit_intercept))
+        w0 = np.zeros(n_features + int(fit_intercept), dtype=X.dtype)
         mask_classes = np.array([-1, 1])
         mask = (y == pos_class)
-        y_bin = np.ones(y.shape, dtype=np.float64)
+        y_bin = np.ones(y.shape, dtype=X.dtype)
         y_bin[~mask] = -1.
         # for compute_class_weight

@@ -646,10 +647,10 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
         else:
             # SAG multinomial solver needs LabelEncoder, not LabelBinarizer
             le = LabelEncoder()
-            Y_multi = le.fit_transform(y)
+            Y_multi = le.fit_transform(y).astype(X.dtype, copy=False)
 
         w0 = np.zeros((classes.size, n_features + int(fit_intercept)),
-                      order='F')
+                      order='F', dtype=X.dtype)
 
         if coef is not None:
             # it must work both giving the bias term and not
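
LabelEncoder.fit_transform returns integer class codes, so the explicit astype(X.dtype, copy=False) above is what keeps the multinomial target in the same floating-point type as X. A quick illustration with made-up data:

import numpy as np
from sklearn.preprocessing import LabelEncoder

X = np.ones((4, 2), dtype=np.float32)
y = np.array(['a', 'b', 'b', 'a'])

codes = LabelEncoder().fit_transform(y)
print(codes.dtype)                           # an integer dtype, e.g. int64
Y_multi = codes.astype(X.dtype, copy=False)
print(Y_multi.dtype)                         # float32, matching X
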
@@ -1204,7 +1205,12 @@ def fit(self, X, y, sample_weight=None):
             raise ValueError("Tolerance for stopping criteria must be "
                              "positive; got (tol=%r)" % self.tol)
 
-        X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64,
+        if self.solver in ['newton-cg']:
+            _dtype = [np.float64, np.float32]
+        else:
+            _dtype = np.float64
+
+        X, y = check_X_y(X, y, accept_sparse='csr', dtype=_dtype,
                          order="C")
         check_classification_targets(y)
         self.classes_ = np.unique(y)
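
The list passed as dtype is what makes the behaviour solver-dependent: check_X_y leaves the input untouched if its dtype is already in the list and otherwise converts to the first entry, so newton-cg can keep float32 while the remaining solvers still receive float64. A short sketch of that contract:

import numpy as np
from sklearn.utils import check_X_y

X_32 = np.ones((6, 2), dtype=np.float32)
y = np.array([0, 1, 0, 1, 0, 1])

X_a, _ = check_X_y(X_32, y, dtype=[np.float64, np.float32])  # newton-cg branch
X_b, _ = check_X_y(X_32, y, dtype=np.float64)                # other solvers
print(X_a.dtype, X_b.dtype)  # float32 float64
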
30 changes: 30 additions & 0 deletions sklearn/linear_model/tests/test_logistic.py
@@ -1129,3 +1129,33 @@ def test_saga_vs_liblinear():
     liblinear.fit(X, y)
     # Convergence for alpha=1e-3 is very slow
     assert_array_almost_equal(saga.coef_, liblinear.coef_, 3)
+
+
+def test_dtype_match():
+    # Test that np.float32 input data is not cast to np.float64 when possible
+
+    X_32 = np.array(X).astype(np.float32)
+    y_32 = np.array(Y1).astype(np.float32)
+    X_64 = np.array(X).astype(np.float64)
+    y_64 = np.array(Y1).astype(np.float64)
+    X_sparse_32 = sp.csr_matrix(X, dtype=np.float32)
+
+    for solver in ['newton-cg']:
+        for multi_class in ['ovr', 'multinomial']:
+
+            # Check type consistency
+            lr_32 = LogisticRegression(solver=solver, multi_class=multi_class)
+            lr_32.fit(X_32, y_32)
+            assert_equal(lr_32.coef_.dtype, X_32.dtype)
+
+            # check consistency with sparsity
+            lr_32_sparse = LogisticRegression(solver=solver,
+                                              multi_class=multi_class)
+            lr_32_sparse.fit(X_sparse_32, y_32)
+            assert_equal(lr_32_sparse.coef_.dtype, X_sparse_32.dtype)
+
+            # Check accuracy consistency
+            lr_64 = LogisticRegression(solver=solver, multi_class=multi_class)
+            lr_64.fit(X_64, y_64)
+            assert_equal(lr_64.coef_.dtype, X_64.dtype)
+            assert_almost_equal(lr_32.coef_, lr_64.coef_.astype(np.float32))
