forked from scikit-learn/scikit-learn
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MRG+2] LogisticRegression convert to float64 (newton-cg) (scikit-lea…
…rn#8835) * Add a test to ensure not changing the input's data type Test that np.float32 input data is not cast to np.float64 when using LR + newton-cg * [WIP] Force X to remain float32. (self.coef_ remains float64 even if X is not) * [WIP] ensure self.coef_ same type as X * keep the np.float32 when multi_class='multinomial' * Avoid hardcoded type for multinomial * pass flake8 * Ensure that the results in 32bits are the same as in 64 * Address Gael's comments for multi_class=='ovr' * Add multi_class=='multinominal' to test * Add support for multi_class=='multinominal' * prefer float64 to float32 * Force X and y to have the same type * Revert "Add support for multi_class=='multinominal'" This reverts commit 4ac33e8. * remvert more stuff * clean up some commmented code * allow class_weight to take advantage of float32 * Add a test where X.dtype is different of y.dtype * Address @raghavrv comments * address the rest of @raghavrv's comments * Revert class_weight * Avoid copying if dtype matches * Address alex comment to the cast from inside _multinomial_loss_grad * address alex comment * add sparsity test * Addressed Tom comment of checking that we keep the 64 aswell
- Loading branch information
1 parent
3e72a6e
commit 5ab17a4
Showing
2 changed files
with
44 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters