Skip to content

Commit

Permalink
TEST Cover all liblinear input formats in test_dtype_match (scikit-le…
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhenrie committed Jul 24, 2019
1 parent 4ccb830 commit c34946b
Showing 1 changed file with 32 additions and 9 deletions.
41 changes: 32 additions & 9 deletions sklearn/linear_model/tests/test_logistic.py
Expand Up @@ -1295,34 +1295,48 @@ def test_saga_vs_liblinear():


@pytest.mark.parametrize('multi_class', ['ovr', 'multinomial'])
@pytest.mark.parametrize('solver', ['newton-cg', 'saga'])
def test_dtype_match(solver, multi_class):
@pytest.mark.parametrize('solver', ['newton-cg', 'liblinear', 'saga'])
@pytest.mark.parametrize('fit_intercept', [False, True])
def test_dtype_match(solver, multi_class, fit_intercept):
# Test that np.float32 input data is not cast to np.float64 when possible
# and that the output is approximately the same no matter the input format.

if solver == 'liblinear' and multi_class == 'multinomial':
pytest.skip('liblinear does not support multinomial classes')

out32_type = np.float64 if solver == 'liblinear' else np.float32

X_32 = np.array(X).astype(np.float32)
y_32 = np.array(Y1).astype(np.float32)
X_64 = np.array(X).astype(np.float64)
y_64 = np.array(Y1).astype(np.float64)
X_sparse_32 = sp.csr_matrix(X, dtype=np.float32)
X_sparse_64 = sp.csr_matrix(X, dtype=np.float64)
solver_tol = 5e-4

lr_templ = LogisticRegression(
solver=solver, multi_class=multi_class,
random_state=42, tol=solver_tol, fit_intercept=True)
# Check type consistency
random_state=42, tol=solver_tol, fit_intercept=fit_intercept)

# Check 32-bit type consistency
lr_32 = clone(lr_templ)
lr_32.fit(X_32, y_32)
assert lr_32.coef_.dtype == X_32.dtype
assert lr_32.coef_.dtype == out32_type

# check consistency with sparsity
# Check 32-bit type consistency with sparsity
lr_32_sparse = clone(lr_templ)
lr_32_sparse.fit(X_sparse_32, y_32)
assert lr_32_sparse.coef_.dtype == X_sparse_32.dtype
assert lr_32_sparse.coef_.dtype == out32_type

# Check accuracy consistency
# Check 64-bit type consistency
lr_64 = clone(lr_templ)
lr_64.fit(X_64, y_64)
assert lr_64.coef_.dtype == X_64.dtype
assert lr_64.coef_.dtype == np.float64

# Check 64-bit type consistency with sparsity
lr_64_sparse = clone(lr_templ)
lr_64_sparse.fit(X_sparse_64, y_64)
assert lr_64_sparse.coef_.dtype == np.float64

# solver_tol bounds the norm of the loss gradient
# dw ~= inv(H)*grad ==> |dw| ~= |inv(H)| * solver_tol, where H - hessian
Expand All @@ -1339,8 +1353,17 @@ def test_dtype_match(solver, multi_class):
# FIXME
atol = 1e-2

# Check accuracy consistency
assert_allclose(lr_32.coef_, lr_64.coef_.astype(np.float32), atol=atol)

if solver == 'saga' and fit_intercept:
# FIXME: SAGA on sparse data fits the intercept inaccurately with the
# default tol and max_iter parameters.
atol = 1e-1

assert_allclose(lr_32.coef_, lr_32_sparse.coef_, atol=atol)
assert_allclose(lr_64.coef_, lr_64_sparse.coef_, atol=atol)


def test_warm_start_converge_LR():
# Test to see that the logistic regression converges on warm start,
Expand Down

0 comments on commit c34946b

Please sign in to comment.