diff --git a/sklearn/linear_model/logistic.py b/sklearn/linear_model/logistic.py
index c649327aedca3..657c3118010ba 100644
--- a/sklearn/linear_model/logistic.py
+++ b/sklearn/linear_model/logistic.py
@@ -338,7 +338,8 @@ def _multinomial_loss_grad(w, X, Y, alpha, sample_weight):
     n_classes = Y.shape[1]
     n_features = X.shape[1]
     fit_intercept = (w.size == n_classes * (n_features + 1))
-    grad = np.zeros((n_classes, n_features + bool(fit_intercept)))
+    grad = np.zeros((n_classes, n_features + bool(fit_intercept)),
+                    dtype=X.dtype)
     loss, p, w = _multinomial_loss(w, X, Y, alpha, sample_weight)
     sample_weight = sample_weight[:, np.newaxis]
     diff = sample_weight * (p - Y)
@@ -609,10 +610,10 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
     # and check length
     # Otherwise set them to 1 for all examples
     if sample_weight is not None:
-        sample_weight = np.array(sample_weight, dtype=np.float64, order='C')
+        sample_weight = np.array(sample_weight, dtype=X.dtype, order='C')
         check_consistent_length(y, sample_weight)
     else:
-        sample_weight = np.ones(X.shape[0])
+        sample_weight = np.ones(X.shape[0], dtype=X.dtype)
 
     # If class_weights is a dict (provided by the user), the weights
     # are assigned to the original labels. If it is "balanced", then
@@ -625,10 +626,10 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
     # For doing a ovr, we need to mask the labels first. for the
     # multinomial case this is not necessary.
     if multi_class == 'ovr':
-        w0 = np.zeros(n_features + int(fit_intercept))
+        w0 = np.zeros(n_features + int(fit_intercept), dtype=X.dtype)
         mask_classes = np.array([-1, 1])
         mask = (y == pos_class)
-        y_bin = np.ones(y.shape, dtype=np.float64)
+        y_bin = np.ones(y.shape, dtype=X.dtype)
         y_bin[~mask] = -1.
         # for compute_class_weight
 
@@ -646,10 +647,10 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,
         else:
             # SAG multinomial solver needs LabelEncoder, not LabelBinarizer
             le = LabelEncoder()
-            Y_multi = le.fit_transform(y)
+            Y_multi = le.fit_transform(y).astype(X.dtype, copy=False)
 
         w0 = np.zeros((classes.size, n_features + int(fit_intercept)),
-                      order='F')
+                      order='F', dtype=X.dtype)
 
     if coef is not None:
         # it must work both giving the bias term and not
@@ -1204,7 +1205,12 @@ def fit(self, X, y, sample_weight=None):
             raise ValueError("Tolerance for stopping criteria must be "
                              "positive; got (tol=%r)" % self.tol)
 
-        X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64,
+        if self.solver in ['newton-cg']:
+            _dtype = [np.float64, np.float32]
+        else:
+            _dtype = np.float64
+
+        X, y = check_X_y(X, y, accept_sparse='csr', dtype=_dtype,
                          order="C")
         check_classification_targets(y)
         self.classes_ = np.unique(y)
diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py
index 7cb7fdd2d7d32..c6f4fbf4a4c4d 100644
--- a/sklearn/linear_model/tests/test_logistic.py
+++ b/sklearn/linear_model/tests/test_logistic.py
@@ -1129,3 +1129,33 @@ def test_saga_vs_liblinear():
                 liblinear.fit(X, y)
                 # Convergence for alpha=1e-3 is very slow
                 assert_array_almost_equal(saga.coef_, liblinear.coef_, 3)
+
+
+def test_dtype_match():
+    # Test that np.float32 input data is not cast to np.float64 when possible
+
+    X_32 = np.array(X).astype(np.float32)
+    y_32 = np.array(Y1).astype(np.float32)
+    X_64 = np.array(X).astype(np.float64)
+    y_64 = np.array(Y1).astype(np.float64)
+    X_sparse_32 = sp.csr_matrix(X, dtype=np.float32)
+
+    for solver in ['newton-cg']:
+        for multi_class in ['ovr', 'multinomial']:
+
+            # Check type consistency
+            lr_32 = LogisticRegression(solver=solver, multi_class=multi_class)
+            lr_32.fit(X_32, y_32)
+            assert_equal(lr_32.coef_.dtype, X_32.dtype)
+
+            # check consistency with sparsity
+            lr_32_sparse = LogisticRegression(solver=solver,
+                                              multi_class=multi_class)
+            lr_32_sparse.fit(X_sparse_32, y_32)
+            assert_equal(lr_32_sparse.coef_.dtype, X_sparse_32.dtype)
+
+            # Check accuracy consistency
+            lr_64 = LogisticRegression(solver=solver, multi_class=multi_class)
+            lr_64.fit(X_64, y_64)
+            assert_equal(lr_64.coef_.dtype, X_64.dtype)
+            assert_almost_equal(lr_32.coef_, lr_64.coef_.astype(np.float32))
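
A minimal usage sketch of the behavior this patch targets, assuming a scikit-learn build with the patch applied: with solver='newton-cg', a float32 feature matrix is fitted without being upcast to float64, so the learned coefficients keep the input dtype. With any other solver, the input is still converted to float64 as before.

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

# Build a float32 feature matrix; labels can stay as integers.
X, y = load_iris(return_X_y=True)
X_32 = X.astype(np.float32)

clf = LogisticRegression(solver='newton-cg', multi_class='multinomial')
clf.fit(X_32, y)

# With this patch the coefficients keep the float32 input dtype;
# without it they would come back as float64.
print(clf.coef_.dtype)  # expected: float32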