Squash all the PR 9040 commits
initial PR commit

seq_dataset.pyx generated from template

seq_dataset.pyx generated from template #2

rename variables

fused types consistency test for seq_dataset

a

sklearn/utils/tests/test_seq_dataset.py

new if statement

add doc

sklearn/utils/seq_dataset.pyx.tp

minor changes

minor changes

typo fix

check numeric accuracy only up to the 5th decimal

Address Oliver's request to change the test name

add test for make_dataset and rename a variable in test_seq_dataset
Imbert Arthur authored and NelleV committed Jul 8, 2018
1 parent c4e223f commit 9463e21
Showing 16 changed files with 726 additions and 247 deletions.
5 changes: 5 additions & 0 deletions .gitignore
@@ -68,3 +68,8 @@ benchmarks/bench_covertype_data/
 .cache
 .pytest_cache/
 _configtest.o.d
+
+# files generated from a template
+sklearn/utils/seq_dataset.pyx
+sklearn/utils/seq_dataset.pxd
+sklearn/linear_model/sag_fast.pyx
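As context for the ignore entries above: these .pyx/.pxd files are generated at build time from Tempita templates (e.g. seq_dataset.pyx.tp, added in this PR). A minimal sketch of that expansion, assuming Cython's vendored Tempita; the template body here is illustrative, not the real one:

from Cython import Tempita

# Expand one class per dtype suffix, the way the real (much larger)
# template in sklearn/utils/seq_dataset.pyx.tp produces the 32/64 variants.
template = Tempita.Template("""
{{for name, c_type in [('64', 'double'), ('32', 'float')]}}
cdef class SequentialDataset{{name}}:
    cdef {{c_type}} current_value
{{endfor}}
""")
print(template.substitute())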
19 changes: 14 additions & 5 deletions sklearn/linear_model/base.py
@@ -32,7 +32,9 @@
 from ..utils.extmath import safe_sparse_dot
 from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale
 from ..utils.fixes import sparse_lsqr
-from ..utils.seq_dataset import ArrayDataset, CSRDataset
+from ..utils.seq_dataset import ArrayDataset32, CSRDataset32
+from ..utils.seq_dataset import ArrayDataset64 as ArrayDataset
+from ..utils.seq_dataset import CSRDataset64 as CSRDataset
 from ..utils.validation import check_is_fitted
 from ..exceptions import NotFittedError
 from ..preprocessing.data import normalize as f_normalize
@@ -53,15 +55,22 @@ def make_dataset(X, y, sample_weight, random_state=None):
     """

     rng = check_random_state(random_state)
-    # seed should never be 0 in SequentialDataset
+    # seed should never be 0 in SequentialDataset64
     seed = rng.randint(1, np.iinfo(np.int32).max)

+    if X.dtype == np.float32:
+        CSRData = CSRDataset32
+        ArrayData = ArrayDataset32
+    else:
+        CSRData = CSRDataset
+        ArrayData = ArrayDataset
+
     if sp.issparse(X):
-        dataset = CSRDataset(X.data, X.indptr, X.indices, y, sample_weight,
-                             seed=seed)
+        dataset = CSRData(X.data, X.indptr, X.indices, y, sample_weight,
+                          seed=seed)
         intercept_decay = SPARSE_INTERCEPT_DECAY
     else:
-        dataset = ArrayDataset(X, y, sample_weight, seed=seed)
+        dataset = ArrayData(X, y, sample_weight, seed=seed)
         intercept_decay = 1.0

     return dataset, intercept_decay
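A hedged usage sketch of the change above (not part of the diff): float32 input is now expected to produce the 32-bit dataset classes instead of being upcast to float64. The module path matches this commit's layout (sklearn.linear_model.base):

import numpy as np
from sklearn.linear_model.base import make_dataset

rng = np.random.RandomState(0)
X = rng.rand(20, 3).astype(np.float32)        # dense float32 design matrix
y = rng.rand(20).astype(np.float32)
sample_weight = np.ones(20, dtype=np.float32)

dataset, intercept_decay = make_dataset(X, y, sample_weight, random_state=0)
print(type(dataset).__name__)  # expected: ArrayDataset32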
8 changes: 4 additions & 4 deletions sklearn/linear_model/logistic.py
@@ -737,7 +737,7 @@ def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True,

     elif solver in ['sag', 'saga']:
         if multi_class == 'multinomial':
-            target = target.astype(np.float64)
+            target = target.astype(X.dtype, copy=False)
             loss = 'multinomial'
         else:
             loss = 'log'
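Why copy=False matters in the replacement line: numpy hands back the same array when the requested dtype already matches, so float64 targets are not re-allocated. A quick illustration:

import numpy as np

target = np.zeros(5, dtype=np.float64)
print(target.astype(np.float64, copy=False) is target)  # True: no copy made
print(target.astype(np.float32, copy=False) is target)  # False: cast required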
@@ -1216,10 +1216,10 @@ def fit(self, X, y, sample_weight=None):
             raise ValueError("Tolerance for stopping criteria must be "
                              "positive; got (tol=%r)" % self.tol)

-        if self.solver in ['newton-cg']:
-            _dtype = [np.float64, np.float32]
-        else:
+        if self.solver in ['lbfgs', 'liblinear']:
             _dtype = np.float64
+        else:
+            _dtype = [np.float64, np.float32]

         X, y = check_X_y(X, y, accept_sparse='csr', dtype=_dtype, order="C",
                          accept_large_sparse=self.solver != 'liblinear')
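A sketch of the intended effect (hedged; assumes this commit's behaviour with the sag solver): solvers other than lbfgs and liblinear should no longer force float32 input up to float64.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=100, n_features=5, random_state=0)
X = X.astype(np.float32)

clf = LogisticRegression(solver='sag', max_iter=1000, random_state=0)
clf.fit(X, y)
print(clf.coef_.dtype)  # expected: float32 after this change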
21 changes: 12 additions & 9 deletions sklearn/linear_model/sag.py
@@ -9,7 +9,7 @@
 import numpy as np

 from .base import make_dataset
-from .sag_fast import sag
+from .sag_fast import sag32, sag64
 from ..exceptions import ConvergenceWarning
 from ..utils import check_array
 from ..utils.extmath import row_norms
@@ -243,8 +243,9 @@ def sag_solver(X, y, sample_weight=None, loss='log', alpha=1., beta=0.,
         max_iter = 1000

     if check_input:
-        X = check_array(X, dtype=np.float64, accept_sparse='csr', order='C')
-        y = check_array(y, dtype=np.float64, ensure_2d=False, order='C')
+        _dtype = [np.float64, np.float32]
+        X = check_array(X, dtype=_dtype, accept_sparse='csr', order='C')
+        y = check_array(y, dtype=_dtype, ensure_2d=False, order='C')

     n_samples, n_features = X.shape[0], X.shape[1]
     # As in SGD, the alpha is scaled by n_samples.
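How the dtype list passed to check_array behaves: an input whose dtype is already in the list is kept as-is, anything else is converted to the first entry (float64).

import numpy as np
from sklearn.utils import check_array

_dtype = [np.float64, np.float32]
print(check_array(np.ones((2, 2), dtype=np.float32), dtype=_dtype).dtype)  # float32
print(check_array(np.ones((2, 2), dtype=np.int32), dtype=_dtype).dtype)   # float64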
@@ -256,13 +257,13 @@ def sag_solver(X, y, sample_weight=None, loss='log', alpha=1., beta=0.,

     # initialization
     if sample_weight is None:
-        sample_weight = np.ones(n_samples, dtype=np.float64, order='C')
+        sample_weight = np.ones(n_samples, dtype=X.dtype, order='C')

     if 'coef' in warm_start_mem.keys():
         coef_init = warm_start_mem['coef']
     else:
         # assume fit_intercept is False
-        coef_init = np.zeros((n_features, n_classes), dtype=np.float64,
+        coef_init = np.zeros((n_features, n_classes), dtype=X.dtype,
                              order='C')

     # coef_init contains possibly the intercept_init at the end.
@@ -272,23 +273,23 @@ def sag_solver(X, y, sample_weight=None, loss='log', alpha=1., beta=0.,
         intercept_init = coef_init[-1, :]
         coef_init = coef_init[:-1, :]
     else:
-        intercept_init = np.zeros(n_classes, dtype=np.float64)
+        intercept_init = np.zeros(n_classes, dtype=X.dtype)

     if 'intercept_sum_gradient' in warm_start_mem.keys():
         intercept_sum_gradient = warm_start_mem['intercept_sum_gradient']
     else:
-        intercept_sum_gradient = np.zeros(n_classes, dtype=np.float64)
+        intercept_sum_gradient = np.zeros(n_classes, dtype=X.dtype)

     if 'gradient_memory' in warm_start_mem.keys():
         gradient_memory_init = warm_start_mem['gradient_memory']
     else:
         gradient_memory_init = np.zeros((n_samples, n_classes),
-                                        dtype=np.float64, order='C')
+                                        dtype=X.dtype, order='C')
     if 'sum_gradient' in warm_start_mem.keys():
         sum_gradient_init = warm_start_mem['sum_gradient']
     else:
         sum_gradient_init = np.zeros((n_features, n_classes),
-                                     dtype=np.float64, order='C')
+                                     dtype=X.dtype, order='C')

     if 'seen' in warm_start_mem.keys():
         seen_init = warm_start_mem['seen']
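A hedged end-to-end sketch: with the allocations above keyed to X.dtype, the solver's warm-start buffers should come back as float32 for float32 input. The (coef, n_iter, warm_start_mem) return signature and the +/-1 target encoding for loss='log' are assumptions from reading sag.py, not part of this diff.

import numpy as np
from sklearn.linear_model.sag import sag_solver

rng = np.random.RandomState(0)
X = rng.rand(50, 4).astype(np.float32)
y = np.where(rng.rand(50) > 0.5, 1.0, -1.0).astype(np.float32)

coef, n_iter, warm_start_mem = sag_solver(X, y, loss='log', alpha=1.)
print(warm_start_mem['sum_gradient'].dtype)  # expected: float32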
@@ -311,6 +312,7 @@ def sag_solver(X, y, sample_weight=None, loss='log', alpha=1., beta=0.,
         raise ZeroDivisionError("Current sag implementation does not handle "
                                 "the case step_size * alpha_scaled == 1")

+    sag = sag64 if X.dtype == np.float64 else sag32
     num_seen, n_iter_ = sag(dataset, coef_init,
                             intercept_init, n_samples,
                             n_features, n_classes, tol,
@@ -327,6 +329,7 @@ def sag_solver(X, y, sample_weight=None, loss='log', alpha=1., beta=0.,
                             intercept_decay,
                             is_saga,
                             verbose)
+
     if n_iter_ == max_iter:
         warnings.warn("The max_iter was reached which means "
                       "the coef_ did not converge", ConvergenceWarning)
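The dispatch added above, isolated as a sketch: the Cython kernel is selected to match X.dtype, so no conversion happens inside the solver loop.

import numpy as np
from sklearn.linear_model.sag_fast import sag32, sag64

def select_sag(X):
    # mirrors `sag = sag64 if X.dtype == np.float64 else sag32`
    return sag64 if X.dtype == np.float64 else sag32

print(select_sag(np.zeros((2, 2), dtype=np.float32)) is sag32)  # True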
