
Replace combine_dfs operator functionality with sklearn's FeatureUnion #117

Closed · rhiever opened this issue Mar 20, 2016 · 23 comments


rhiever commented Mar 20, 2016

Currently, combine_dfs uses custom code to combine the features from separate pipelines into a single feature set. We should instead use sklearn's FeatureUnion class within combine_dfs, which I believe will do a better and more efficient job of combining the features.

Here's an example provided by @amueller:

make_pipeline(make_union(PolynomialFeatures(), PCA()), RFE(RandomForestClassifier()))

rhiever commented Jun 1, 2016

@teaearlgraycold, do you remember why we decided this wasn't feasible with FeatureUnion? We should document that here.

danthedaniel commented

IIRC it only works on feature preprocessors. You can't pass it classifiers.

Wherever that was decided, it should be apparent from the documentation.



rhiever commented Jun 1, 2016

Ah, right: FeatureUnion only accepts sklearn preprocessors that have a transform() function. So unless there's an easy way to wrap the sklearn classifiers such that they have a transform() function that simply adds the classifier's predictions as a new feature...
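
A wrapper like that would only be a few lines. Here's a minimal sketch; the ClassifierTransformer name and its behavior (exposing the wrapped classifier's predictions through transform()) are hypothetical, not anything sklearn ships:

from sklearn.base import BaseEstimator, TransformerMixin

class ClassifierTransformer(BaseEstimator, TransformerMixin):
    """Hypothetical wrapper: exposes a classifier's predictions via transform()."""

    def __init__(self, classifier):
        self.classifier = classifier

    def fit(self, X, y=None):
        self.classifier.fit(X, y)
        return self

    def transform(self, X):
        # Return predictions as a single-column feature matrix; FeatureUnion
        # would concatenate this column with the other branches' outputs.
        return self.classifier.predict(X).reshape(-1, 1)

# Usage sketch:
# make_union(PolynomialFeatures(), ClassifierTransformer(RandomForestClassifier()))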


rhiever commented Jun 1, 2016

@amueller, any ideas on how to make it such that sklearn Classifiers can be included (as discussed above) in FeatureUnions?


amueller commented Jun 2, 2016

Wait, it depends on what you want. For feature selection? Did I provide the example in the first post? Seems odd, lol


rhiever commented Jun 2, 2016

Let's use this pipeline as an example:

[pipeline diagram from the first post: PolynomialFeatures and PCA combined via make_union, feeding into RFE(RandomForestClassifier())]

but instead of PCA there's a Random Forest there. So what this pipeline does is:

  1. Take a copy of the data set and apply Polynomial Features to create data set A

  2. Take another copy of the data set, fit a Random Forest, and take the predictions of the Random Forest and add them to the data set as a new feature to create data set B

  3. Combine the features of data sets A and B into a single data set

  4. Apply RFE

  5. Fit another Random Forest to the features left after RFE and use that Random Forest's predictions as the final prediction for the pipeline

So we're looking for a sklearn-compatible way to represent that as a pipeline. We originally thought we could do:

make_pipeline(make_union(PolynomialFeatures(), RandomForestClassifier()), RFE(), RandomForestClassifier())

but I'm pretty sure that doesn't work out of the box.
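
For what it's worth, a quick way to confirm: in the sklearn versions I've tried, FeatureUnion validates at fit time that every step implements transform(), so a bare classifier in the union raises a TypeError:

from sklearn.pipeline import make_union
from sklearn.preprocessing import PolynomialFeatures
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_digits

X, y = load_digits(return_X_y=True)
union = make_union(PolynomialFeatures(), RandomForestClassifier())
union.fit(X, y)  # TypeError: all estimators should implement fit and transform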


amueller commented Jun 2, 2016

> and take the predictions of the Random Forest

That was the part I wasn't sure about. OK then, VotingClassifier.


rhiever commented Jun 2, 2016

How would that look in sklearn code? Like this?

make_pipeline(make_union(PolynomialFeatures(), VotingClassifier(estimators=[('rf1', RandomForestClassifier())])), RFE(), RandomForestClassifier())


amueller commented Jun 2, 2016

Yeah, only that the RFE should go around the last RandomForestClassifier.


rhiever commented Jun 2, 2016

I'm not sure what you mean?


amueller commented Jun 2, 2016

make_pipeline(make_union(PolynomialFeatures(), VotingClassifier(estimators=[('rf1', RandomForestClassifier())])), RFE(RandomForestClassifier()))


rhiever commented Jun 2, 2016

Wait, why? Is that specific to RFE or feature selection methods?


amueller commented Jun 2, 2016

RFE is model-based feature selection. How should it do feature selection without a model? SelectFromModel is also a meta-estimator, while the feature selection methods that are not model-based are not.
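
To make the distinction concrete, here's a sketch of the three cases (the particular estimators are just illustrative choices):

from sklearn.feature_selection import RFE, SelectFromModel, SelectKBest, f_classif
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC

# Model-based: meta-estimators that need a model to rank features.
rfe = RFE(estimator=SVC(kernel='linear'))
sfm = SelectFromModel(estimator=RandomForestClassifier())

# Not model-based: scores each feature directly from the data, no model required.
skb = SelectKBest(score_func=f_classif, k=5)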


amueller commented Jun 2, 2016

Btw, this is feature selection using RF. That doesn't necessarily imply classification with RF.


rhiever commented Jun 2, 2016

So we could do something like:

make_pipeline(make_union(PolynomialFeatures(),
                         VotingClassifier(estimators=[('rf1', RandomForestClassifier())])),
              RFE(estimator=SVC(kernel='linear')),
              RandomForestClassifier())

if we wanted RF classification at the end of the pipeline.


amueller commented Jun 2, 2016

RFE has a predict, so both of the pipelines you outlined can predict. They do different things, though.
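
For reference, a minimal sketch of what RFE's predict does: it masks X down to the selected features and delegates to the estimator that was refit on them.

from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import RFE

X, y = load_digits(return_X_y=True)
rfe = RFE(estimator=RandomForestClassifier(), n_features_to_select=10).fit(X, y)
preds = rfe.predict(X)  # reduces X to the 10 selected features, then calls the inner RF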


rhiever commented Jun 2, 2016

Ahhh, I get it now. Just looked at the docs for the RFE predict function. Basically:

make_pipeline(make_union(PolynomialFeatures(),
                         VotingClassifier(estimators=[('rf1', RandomForestClassifier())])),
              RFE(estimator=RandomForestClassifier()))

and

make_pipeline(make_union(PolynomialFeatures(),
                         VotingClassifier(estimators=[('rf1', RandomForestClassifier())])),
              RFE(estimator=RandomForestClassifier()),
              RandomForestClassifier())

would do the same thing.


amueller commented Jun 2, 2016

yes


rhiever commented Jun 2, 2016

That's awesome. Looks like we'll be able to export to sklearn pipelines after all, @teaearlgraycold!


rhiever commented Jun 2, 2016

Here's some example code that works:

from sklearn.pipeline import make_pipeline, make_union
from sklearn.preprocessing import PolynomialFeatures
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.feature_selection import SelectKBest, VarianceThreshold
from sklearn.datasets import load_digits
from sklearn.model_selection import cross_val_score  # sklearn.cross_validation in releases before 0.18

data = load_digits()

clf = make_pipeline(
    # Union of two branches: polynomial feature expansion, and a
    # VotingClassifier whose transform() emits its estimators' predictions.
    make_union(PolynomialFeatures(),
               VotingClassifier(estimators=[('rf1', RandomForestClassifier())])),
    VarianceThreshold(),  # drop zero-variance features from the union
    SelectKBest(k=5),     # keep the 5 highest-scoring features
    RandomForestClassifier())

cross_val_score(clf, data.data, data.target, cv=5)
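
This works because VotingClassifier, unlike a bare classifier, implements transform(): with the default voting='hard' it returns each sub-estimator's predictions as columns, so FeatureUnion accepts it and concatenates those predictions with the PolynomialFeatures output.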

danthedaniel commented

I'll keep that in mind when I start working on the refactored code's export utils


rhiever commented Aug 19, 2016

This feature will be in the 0.5 release.
