Skip to content

Commit

Permalink
added support for lightning estimators
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Feb 26, 2020
1 parent 3086e9a commit 2a3cc2c
Show file tree
Hide file tree
Showing 9 changed files with 485 additions and 69 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Expand Up @@ -16,6 +16,7 @@ before_install:
- bash .travis/setup.sh

install:
- pip install Cython
- pip install -r requirements-test.txt

script:
Expand Down
3 changes: 2 additions & 1 deletion Dockerfile
Expand Up @@ -26,6 +26,7 @@ RUN apt-get update && \
WORKDIR /m2cgen

COPY requirements-test.txt ./
RUN pip3 install --no-cache-dir -r requirements-test.txt
RUN pip3 install --no-cache-dir Cython && \
pip3 install --no-cache-dir -r requirements-test.txt

CMD python3 setup.py develop && pytest -v -x --fast
4 changes: 2 additions & 2 deletions README.md
Expand Up @@ -41,8 +41,8 @@ pip install m2cgen

| | Classification | Regression |
| --- | --- | --- |
| **Linear** | <ul><li>LogisticRegression</li><li>LogisticRegressionCV</li><li>PassiveAggressiveClassifier</li><li>Perceptron</li><li>RidgeClassifier</li><li>RidgeClassifierCV</li><li>SGDClassifier</li></ul> | <ul><li>scikit-learn<ul><li>ARDRegression</li><li>BayesianRidge</li><li>ElasticNet</li><li>ElasticNetCV</li><li>HuberRegressor</li><li>Lars</li><li>LarsCV</li><li>Lasso</li><li>LassoCV</li><li>LassoLars</li><li>LassoLarsCV</li><li>LassoLarsIC</li><li>LinearRegression</li><li>OrthogonalMatchingPursuit</li><li>OrthogonalMatchingPursuitCV</li><li>PassiveAggressiveRegressor</li><li>RANSACRegressor(only supported regression estimators can be used as a base estimator)</li><li>Ridge</li><li>RidgeCV</li><li>SGDRegressor</li><li>TheilSenRegressor</li></ul><li>StatsModels<ul><li>Generalized Least Squares (GLS)</li><li>Ordinary Least Squares (OLS)</li><li>Weighted Least Squares (WLS)</li></ul></li><ul> |
| **SVM** | <ul><li>LinearSVC</li><li>NuSVC</li><li>SVC</li></ul> | <ul><li>LinearSVR</li><li>NuSVR</li><li>SVR</li></ul> |
| **Linear** | <ul><li>scikit-learn<ul><li>LogisticRegression</li><li>LogisticRegressionCV</li><li>PassiveAggressiveClassifier</li><li>Perceptron</li><li>RidgeClassifier</li><li>RidgeClassifierCV</li><li>SGDClassifier</li></ul></li><li>lightning<ul><li>AdaGradClassifier</li><li>CDClassifier</li><li>FistaClassifier</li><li>SAGAClassifier</li><li>SAGClassifier</li><li>SDCAClassifier</li><li>SGDClassifier</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>ARDRegression</li><li>BayesianRidge</li><li>ElasticNet</li><li>ElasticNetCV</li><li>HuberRegressor</li><li>Lars</li><li>LarsCV</li><li>Lasso</li><li>LassoCV</li><li>LassoLars</li><li>LassoLarsCV</li><li>LassoLarsIC</li><li>LinearRegression</li><li>OrthogonalMatchingPursuit</li><li>OrthogonalMatchingPursuitCV</li><li>PassiveAggressiveRegressor</li><li>RANSACRegressor(only supported regression estimators can be used as a base estimator)</li><li>Ridge</li><li>RidgeCV</li><li>SGDRegressor</li><li>TheilSenRegressor</li></ul><li>StatsModels<ul><li>Generalized Least Squares (GLS)</li><li>Ordinary Least Squares (OLS)</li><li>Weighted Least Squares (WLS)</li></ul><li>lightning<ul><li>AdaGradRegressor</li><li>CDRegressor</li><li>FistaRegressor</li><li>SAGARegressor</li><li>SAGRegressor</li><li>SDCARegressor</li></ul></li></ul> |
| **SVM** | <ul><li>scikit-learn<ul><li>LinearSVC</li><li>NuSVC</li><li>SVC</li></ul></li><li>lightning<ul><li>LinearSVC</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>LinearSVR</li><li>NuSVR</li><li>SVR</li></ul></li><li>lightning<ul><li>LinearSVR</li></ul></li></ul> |
| **Tree** | <ul><li>DecisionTreeClassifier</li><li>ExtraTreeClassifier</li></ul> | <ul><li>DecisionTreeRegressor</li><li>ExtraTreeRegressor</li></ul> |
| **Random Forest** | <ul><li>ExtraTreesClassifier</li><li>LGBMClassifier(rf booster only)</li><li>RandomForestClassifier</li><li>XGBRFClassifier(binary only, multiclass is not supported yet)</li></ul> | <ul><li>ExtraTreesRegressor</li><li>LGBMRegressor(rf booster only)</li><li>RandomForestRegressor</li><li>XGBRFRegressor</li></ul> |
| **Boosting** | <ul><li>LGBMClassifier(gbdt/dart/goss booster only)</li><li>XGBClassifier(gbtree/gblinear booster only)</li><ul> | <ul><li>LGBMRegressor(gbdt/dart/goss booster only)</li><li>XGBRegressor(gbtree/gblinear booster only)</li></ul> |
Expand Down
163 changes: 108 additions & 55 deletions m2cgen/assemblers/__init__.py
Expand Up @@ -25,75 +25,128 @@

SUPPORTED_MODELS = {
# LightGBM
"LGBMClassifier": LightGBMModelAssembler,
"LGBMRegressor": LightGBMModelAssembler,
"lightgbm.sklearn.LGBMClassifier": LightGBMModelAssembler,
"lightgbm.sklearn.LGBMRegressor": LightGBMModelAssembler,

# XGBoost
"XGBClassifier": XGBoostModelAssemblerSelector,
"XGBRFClassifier": XGBoostModelAssemblerSelector,
"XGBRegressor": XGBoostModelAssemblerSelector,
"XGBRFRegressor": XGBoostModelAssemblerSelector,

# SVM
"LinearSVC": SklearnLinearModelAssembler,
"LinearSVR": SklearnLinearModelAssembler,
"NuSVC": SVMModelAssembler,
"NuSVR": SVMModelAssembler,
"SVC": SVMModelAssembler,
"SVR": SVMModelAssembler,
"xgboost.sklearn.XGBClassifier": XGBoostModelAssemblerSelector,
"xgboost.sklearn.XGBRFClassifier": XGBoostModelAssemblerSelector,
"xgboost.sklearn.XGBRegressor": XGBoostModelAssemblerSelector,
"xgboost.sklearn.XGBRFRegressor": XGBoostModelAssemblerSelector,

# Sklearn SVM
"sklearn.svm.classes.LinearSVC": SklearnLinearModelAssembler,
"sklearn.svm.classes.LinearSVR": SklearnLinearModelAssembler,
"sklearn.svm.classes.NuSVC": SVMModelAssembler,
"sklearn.svm.classes.NuSVR": SVMModelAssembler,
"sklearn.svm.classes.LinearSVCSVC": SVMModelAssembler,
"sklearn.svm.classes.SVR": SVMModelAssembler,

# lightning SVM
"lightning.impl.dual_cd.LinearSVC": SklearnLinearModelAssembler,
"lightning.impl.dual_cd.LinearSVR": SklearnLinearModelAssembler,

# Sklearn Linear Regressors
"ARDRegression": SklearnLinearModelAssembler,
"BayesianRidge": SklearnLinearModelAssembler,
"ElasticNet": SklearnLinearModelAssembler,
"ElasticNetCV": SklearnLinearModelAssembler,
"HuberRegressor": SklearnLinearModelAssembler,
"Lars": SklearnLinearModelAssembler,
"LarsCV": SklearnLinearModelAssembler,
"Lasso": SklearnLinearModelAssembler,
"LassoCV": SklearnLinearModelAssembler,
"LassoLars": SklearnLinearModelAssembler,
"LassoLarsCV": SklearnLinearModelAssembler,
"LassoLarsIC": SklearnLinearModelAssembler,
"LinearRegression": SklearnLinearModelAssembler,
"OrthogonalMatchingPursuit": SklearnLinearModelAssembler,
"OrthogonalMatchingPursuitCV": SklearnLinearModelAssembler,
"PassiveAggressiveRegressor": SklearnLinearModelAssembler,
"RANSACRegressor": RANSACModelAssembler,
"Ridge": SklearnLinearModelAssembler,
"RidgeCV": SklearnLinearModelAssembler,
"SGDRegressor": SklearnLinearModelAssembler,
"TheilSenRegressor": SklearnLinearModelAssembler,
"sklearn.linear_model.bayes.ARDRegression": SklearnLinearModelAssembler,
"sklearn.linear_model.bayes.BayesianRidge": SklearnLinearModelAssembler,
"sklearn.linear_model.coordinate_descent.ElasticNet":
SklearnLinearModelAssembler,
"sklearn.linear_model.coordinate_descent.ElasticNetCV":
SklearnLinearModelAssembler,
"sklearn.linear_model.huber.HuberRegressor": SklearnLinearModelAssembler,
"sklearn.linear_model.least_angle.Lars": SklearnLinearModelAssembler,
"sklearn.linear_model.least_angle.LarsCV": SklearnLinearModelAssembler,
"sklearn.linear_model.coordinate_descent.Lasso":
SklearnLinearModelAssembler,
"sklearn.linear_model.coordinate_descent.LassoCV":
SklearnLinearModelAssembler,
"sklearn.linear_model.least_angle.LassoLars":
SklearnLinearModelAssembler,
"sklearn.linear_model.least_angle.LassoLarsCV":
SklearnLinearModelAssembler,
"sklearn.linear_model.least_angle.LassoLarsIC":
SklearnLinearModelAssembler,
"sklearn.linear_model.base.LinearRegression":
SklearnLinearModelAssembler,
"sklearn.linear_model.omp.OrthogonalMatchingPursuit":
SklearnLinearModelAssembler,
"sklearn.linear_model.omp.OrthogonalMatchingPursuitCV":
SklearnLinearModelAssembler,
"sklearn.linear_model.passive_aggressive.PassiveAggressiveRegressor":
SklearnLinearModelAssembler,
"sklearn.linear_model.ransac.RANSACRegressor": RANSACModelAssembler,
"sklearn.linear_model.ridge.Ridge": SklearnLinearModelAssembler,
"sklearn.linear_model.ridge.RidgeCV": SklearnLinearModelAssembler,
"sklearn.linear_model.stochastic_gradient.SGDRegressor":
SklearnLinearModelAssembler,
"sklearn.linear_model.theil_sen.TheilSenRegressor":
SklearnLinearModelAssembler,

# Statsmodels Linear Regressors
"RegressionResultsWrapper": StatsmodelsLinearModelAssembler,
"RegularizedResultsWrapper": StatsmodelsLinearModelAssembler,

# Linear Classifiers
"LogisticRegression": SklearnLinearModelAssembler,
"LogisticRegressionCV": SklearnLinearModelAssembler,
"PassiveAggressiveClassifier": SklearnLinearModelAssembler,
"Perceptron": SklearnLinearModelAssembler,
"RidgeClassifier": SklearnLinearModelAssembler,
"RidgeClassifierCV": SklearnLinearModelAssembler,
"SGDClassifier": SklearnLinearModelAssembler,
"statsmodels.regression.linear_model.RegressionResultsWrapper":
StatsmodelsLinearModelAssembler,
"statsmodels.base.elastic_net.RegularizedResultsWrapper":
StatsmodelsLinearModelAssembler,

# lightning Linear Regressors
"lightning.impl.adagrad.AdaGradRegressor": SklearnLinearModelAssembler,
"lightning.impl.primal_cd.CDRegressor": SklearnLinearModelAssembler,
"lightning.impl.fista.FistaRegressor": SklearnLinearModelAssembler,
"lightning.impl.sag.SAGARegressor": SklearnLinearModelAssembler,
"lightning.impl.sag.SAGRegressor": SklearnLinearModelAssembler,
"lightning.impl.sdca.SDCARegressor": SklearnLinearModelAssembler,

# Sklearn Linear Classifiers
"sklearn.linear_model.logistic.LogisticRegression":
SklearnLinearModelAssembler,
"sklearn.linear_model.logistic.LogisticRegressionCV":
SklearnLinearModelAssembler,
"sklearn.linear_model.passive_aggressive.PassiveAggressiveClassifier":
SklearnLinearModelAssembler,
"sklearn.linear_model.perceptron.Perceptron":
SklearnLinearModelAssembler,
"sklearn.linear_model.ridge.RidgeClassifier":
SklearnLinearModelAssembler,
"sklearn.linear_model.ridge.RidgeClassifierCV":
SklearnLinearModelAssembler,
"sklearn.linear_model.stochastic_gradient.SGDClassifier":
SklearnLinearModelAssembler,

# lightning Linear Classifiers
"lightning.impl.adagrad.AdaGradClassifier": SklearnLinearModelAssembler,
"lightning.impl.primal_cd.CDClassifier": SklearnLinearModelAssembler,
"lightning.impl.fista.FistaClassifier": SklearnLinearModelAssembler,
"lightning.impl.sag.SAGAClassifier": SklearnLinearModelAssembler,
"lightning.impl.sag.SAGClassifier": SklearnLinearModelAssembler,
"lightning.impl.sdca.SDCAClassifier": SklearnLinearModelAssembler,
"lightning.impl.sgd.SGDClassifier": SklearnLinearModelAssembler,

# Decision trees
"DecisionTreeClassifier": TreeModelAssembler,
"DecisionTreeRegressor": TreeModelAssembler,
"ExtraTreeClassifier": TreeModelAssembler,
"ExtraTreeRegressor": TreeModelAssembler,
"sklearn.tree.tree.DecisionTreeClassifier": TreeModelAssembler,
"sklearn.tree.tree.DecisionTreeRegressor": TreeModelAssembler,
"sklearn.tree.tree.ExtraTreeClassifier": TreeModelAssembler,
"sklearn.tree.tree.ExtraTreeRegressor": TreeModelAssembler,

# Ensembles
"ExtraTreesClassifier": RandomForestModelAssembler,
"ExtraTreesRegressor": RandomForestModelAssembler,
"RandomForestClassifier": RandomForestModelAssembler,
"RandomForestRegressor": RandomForestModelAssembler,
"sklearn.ensemble.forest.ExtraTreesClassifier":
RandomForestModelAssembler,
"sklearn.ensemble.forest.ExtraTreesRegressor":
RandomForestModelAssembler,
"sklearn.ensemble.forest.RandomForestClassifier":
RandomForestModelAssembler,
"sklearn.ensemble.forest.RandomForestRegressor":
RandomForestModelAssembler,
}


def _get_full_model_name(model):
type_name = type(model)
return "{}.{}".format(type_name.__module__,
type_name.__name__)


def get_assembler_cls(model):
model_name = type(model).__name__
model_name = _get_full_model_name(model)
assembler_cls = SUPPORTED_MODELS.get(model_name)

if not assembler_cls:
Expand Down
3 changes: 2 additions & 1 deletion m2cgen/assemblers/linear.py
Expand Up @@ -33,7 +33,8 @@ def _get_coef(self):
class SklearnLinearModelAssembler(BaseLinearModelAssembler):

def _get_intercept(self):
return self.model.intercept_
return getattr(self.model, "intercept_",
np.zeros(self._get_coef().shape[0]))

def _get_coef(self):
return self.model.coef_
Expand Down
3 changes: 2 additions & 1 deletion requirements-test.txt
@@ -1,4 +1,4 @@
numpy==1.15.1
numpy==1.16.1
scipy==1.1.0
scikit-learn==0.20.2
xgboost==0.90
Expand All @@ -10,3 +10,4 @@ coveralls==1.9.2
pytest-cov==2.8.1
py-mini-racer==0.1.18
statsmodels==0.10.2
git+git://github.com/scikit-learn-contrib/lightning.git@b96f9c674968496e854078163c8814049a7b9f43

0 comments on commit 2a3cc2c

Please sign in to comment.