Skip to content

Commit

Permalink
FIX force pipeline steps to be list not a tuple (scikit-learn#9604)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorisvandenbossche authored and AishwaryaRK committed Aug 29, 2017
1 parent f9d5e8a commit 4b99bdf
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 4 deletions.
5 changes: 2 additions & 3 deletions sklearn/pipeline.py
Expand Up @@ -17,7 +17,6 @@
from .base import clone, TransformerMixin
from .externals.joblib import Parallel, delayed, Memory
from .externals import six
from .utils import tosequence
from .utils.metaestimators import if_delegate_has_method
from .utils import Bunch

Expand Down Expand Up @@ -112,7 +111,7 @@ class Pipeline(_BaseComposition):

def __init__(self, steps, memory=None):
# shallow copy of steps
self.steps = tosequence(steps)
self.steps = list(steps)
self._validate_steps()
self.memory = memory

Expand Down Expand Up @@ -624,7 +623,7 @@ class FeatureUnion(_BaseComposition, TransformerMixin):
"""
def __init__(self, transformer_list, n_jobs=1, transformer_weights=None):
self.transformer_list = tosequence(transformer_list)
self.transformer_list = list(transformer_list)
self.n_jobs = n_jobs
self.transformer_weights = transformer_weights
self._validate_transformers()
Expand Down
16 changes: 16 additions & 0 deletions sklearn/tests/test_pipeline.py
Expand Up @@ -208,6 +208,18 @@ def test_pipeline_init():
assert_equal(params, params2)


def test_pipeline_init_tuple():
# Pipeline accepts steps as tuple
X = np.array([[1, 2]])
pipe = Pipeline((('transf', Transf()), ('clf', FitParamT())))
pipe.fit(X, y=None)
pipe.score(X)

pipe.set_params(transf=None)
pipe.fit(X, y=None)
pipe.score(X)


def test_pipeline_methods_anova():
# Test the various methods of the pipeline (anova).
iris = load_iris()
Expand Down Expand Up @@ -425,6 +437,10 @@ def test_feature_union():
FeatureUnion,
[("transform", Transf()), ("no_transform", NoTrans())])

# test that init accepts tuples
fs = FeatureUnion((("svd", svd), ("select", select)))
fs.fit(X, y)


def test_make_union():
pca = PCA(svd_solver='full')
Expand Down
2 changes: 1 addition & 1 deletion sklearn/utils/metaestimators.py
Expand Up @@ -51,7 +51,7 @@ def _set_params(self, attr, **params):

def _replace_estimator(self, attr, name, new_val):
# assumes `name` is a valid estimator name
new_estimators = getattr(self, attr)[:]
new_estimators = list(getattr(self, attr))
for i, (estimator_name, _) in enumerate(new_estimators):
if estimator_name == name:
new_estimators[i] = (name, new_val)
Expand Down

0 comments on commit 4b99bdf

Please sign in to comment.