Skip to content

Commit

Permalink
added support for lightning estimators
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Feb 24, 2020
1 parent 3086e9a commit 74932bd
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 12 deletions.
6 changes: 3 additions & 3 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ python:

env:
- TEST=API
- TEST=E2E LANG="c_lang or python or java or go_lang or javascript or php"
- TEST=E2E LANG="c_sharp or visual_basic or powershell"
- TEST=E2E LANG="r_lang"
- TEST=E2E LANG="python"
# - TEST=E2E LANG="c_sharp or visual_basic or powershell"
# - TEST=E2E LANG="r_lang"

before_install:
- bash .travis/setup.sh
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ pip install m2cgen

| | Classification | Regression |
| --- | --- | --- |
| **Linear** | <ul><li>LogisticRegression</li><li>LogisticRegressionCV</li><li>PassiveAggressiveClassifier</li><li>Perceptron</li><li>RidgeClassifier</li><li>RidgeClassifierCV</li><li>SGDClassifier</li></ul> | <ul><li>scikit-learn<ul><li>ARDRegression</li><li>BayesianRidge</li><li>ElasticNet</li><li>ElasticNetCV</li><li>HuberRegressor</li><li>Lars</li><li>LarsCV</li><li>Lasso</li><li>LassoCV</li><li>LassoLars</li><li>LassoLarsCV</li><li>LassoLarsIC</li><li>LinearRegression</li><li>OrthogonalMatchingPursuit</li><li>OrthogonalMatchingPursuitCV</li><li>PassiveAggressiveRegressor</li><li>RANSACRegressor(only supported regression estimators can be used as a base estimator)</li><li>Ridge</li><li>RidgeCV</li><li>SGDRegressor</li><li>TheilSenRegressor</li></ul><li>StatsModels<ul><li>Generalized Least Squares (GLS)</li><li>Ordinary Least Squares (OLS)</li><li>Weighted Least Squares (WLS)</li></ul></li><ul> |
| **SVM** | <ul><li>LinearSVC</li><li>NuSVC</li><li>SVC</li></ul> | <ul><li>LinearSVR</li><li>NuSVR</li><li>SVR</li></ul> |
| **Linear** | <ul><li>scikit-learn<ul><li>LogisticRegression</li><li>LogisticRegressionCV</li><li>PassiveAggressiveClassifier</li><li>Perceptron</li><li>RidgeClassifier</li><li>RidgeClassifierCV</li><li>SGDClassifier</li></ul></li><li>lightning<ul><li>AdaGradClassifier</li><li>CDClassifier</li><li>FistaClassifier</li><li>SAGAClassifier</li><li>SAGClassifier</li><li>SDCAClassifier</li><li>SGDClassifier</li><li>SVRGClassifier</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>ARDRegression</li><li>BayesianRidge</li><li>ElasticNet</li><li>ElasticNetCV</li><li>HuberRegressor</li><li>Lars</li><li>LarsCV</li><li>Lasso</li><li>LassoCV</li><li>LassoLars</li><li>LassoLarsCV</li><li>LassoLarsIC</li><li>LinearRegression</li><li>OrthogonalMatchingPursuit</li><li>OrthogonalMatchingPursuitCV</li><li>PassiveAggressiveRegressor</li><li>RANSACRegressor(only supported regression estimators can be used as a base estimator)</li><li>Ridge</li><li>RidgeCV</li><li>SGDRegressor</li><li>TheilSenRegressor</li></ul><li>StatsModels<ul><li>Generalized Least Squares (GLS)</li><li>Ordinary Least Squares (OLS)</li><li>Weighted Least Squares (WLS)</li></ul><li>lightning<ul><li>AdaGradRegressor</li><li>CDRegressor</li><li>FistaRegressor</li><li>SAGARegressor</li><li>SAGRegressor</li><li>SDCARegressor</li><li>SGDRegressor</li></ul></li></ul> |
| **SVM** | <ul><li>scikit-learn<ul><li>LinearSVC</li><li>NuSVC</li><li>SVC</li></ul></li><li>lightning<ul><li>KernelSVC</li><li>LinearSVC</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>LinearSVR</li><li>NuSVR</li><li>SVR</li></ul></li><li>lightning<ul><li>LinearSVR</li></ul></li></ul> |
| **Tree** | <ul><li>DecisionTreeClassifier</li><li>ExtraTreeClassifier</li></ul> | <ul><li>DecisionTreeRegressor</li><li>ExtraTreeRegressor</li></ul> |
| **Random Forest** | <ul><li>ExtraTreesClassifier</li><li>LGBMClassifier(rf booster only)</li><li>RandomForestClassifier</li><li>XGBRFClassifier(binary only, multiclass is not supported yet)</li></ul> | <ul><li>ExtraTreesRegressor</li><li>LGBMRegressor(rf booster only)</li><li>RandomForestRegressor</li><li>XGBRFRegressor</li></ul> |
| **Boosting** | <ul><li>LGBMClassifier(gbdt/dart/goss booster only)</li><li>XGBClassifier(gbtree/gblinear booster only)</li><ul> | <ul><li>LGBMRegressor(gbdt/dart/goss booster only)</li><li>XGBRegressor(gbtree/gblinear booster only)</li></ul> |
Expand Down
21 changes: 20 additions & 1 deletion m2cgen/assemblers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,15 @@
"RegressionResultsWrapper": StatsmodelsLinearModelAssembler,
"RegularizedResultsWrapper": StatsmodelsLinearModelAssembler,

# Linear Classifiers
# lightning Linear Regressors
"AdaGradRegressor": SklearnLinearModelAssembler,
"CDRegressor": SklearnLinearModelAssembler,
"FistaRegressor": SklearnLinearModelAssembler,
"SAGARegressor": SklearnLinearModelAssembler,
"SAGRegressor": SklearnLinearModelAssembler,
"SDCARegressor": SklearnLinearModelAssembler,

# Sklearn Linear Classifiers
"LogisticRegression": SklearnLinearModelAssembler,
"LogisticRegressionCV": SklearnLinearModelAssembler,
"PassiveAggressiveClassifier": SklearnLinearModelAssembler,
Expand All @@ -78,6 +86,17 @@
"RidgeClassifierCV": SklearnLinearModelAssembler,
"SGDClassifier": SklearnLinearModelAssembler,

# lightning Linear Classifiers
"AdaGradClassifier": SklearnLinearModelAssembler,
"CDClassifier": SklearnLinearModelAssembler,
"FistaClassifier": SklearnLinearModelAssembler,
"KernelSVC": SklearnLinearModelAssembler,
"LinearSVC": SklearnLinearModelAssembler,
"SAGAClassifier": SklearnLinearModelAssembler,
"SAGClassifier": SklearnLinearModelAssembler,
"SDCAClassifier": SklearnLinearModelAssembler,
"SVRGClassifier": SklearnLinearModelAssembler,

# Decision trees
"DecisionTreeClassifier": TreeModelAssembler,
"DecisionTreeRegressor": TreeModelAssembler,
Expand Down
2 changes: 1 addition & 1 deletion m2cgen/assemblers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _get_coef(self):
class SklearnLinearModelAssembler(BaseLinearModelAssembler):

def _get_intercept(self):
return self.model.intercept_
return getattr(self.model, "intercept_", 0)

def _get_coef(self):
return self.model.coef_
Expand Down
2 changes: 2 additions & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ coveralls==1.9.2
pytest-cov==2.8.1
py-mini-racer==0.1.18
statsmodels==0.10.2
sklearn-contrib-lightning==0.5.0
Cython
77 changes: 72 additions & 5 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np
import xgboost
import statsmodels.api as sm
import lightning.classification as light_clf
import lightning.regression as light_reg
from sklearn import linear_model, svm
from sklearn import tree
from sklearn import ensemble
Expand Down Expand Up @@ -182,12 +184,17 @@ def classification_binary_random(model):
classification_binary_random(
xgboost.XGBClassifier(**XGBOOST_PARAMS_LARGE)),
# Linear SVM
# Sklearn Linear SVM
regression(svm.LinearSVR(random_state=RANDOM_SEED)),
classification(svm.LinearSVC(random_state=RANDOM_SEED)),
classification_binary(svm.LinearSVC(random_state=RANDOM_SEED)),
# SVM
# lightning Linear SVM
regression(light_reg.LinearSVR(random_state=RANDOM_SEED)),
classification(light_clf.LinearSVC(random_state=RANDOM_SEED)),
classification_binary(light_clf.LinearSVC(random_state=RANDOM_SEED)),
# Sklearn SVM
regression(svm.NuSVR(kernel="rbf")),
regression(svm.SVR(kernel="rbf")),
classification(svm.NuSVC(kernel="rbf", **SVC_PARAMS)),
Expand All @@ -198,6 +205,20 @@ def classification_binary_random(model):
classification_binary(svm.SVC(kernel="rbf", **SVC_PARAMS)),
classification_binary(svm.SVC(kernel="sigmoid", **SVC_PARAMS)),
# lightning SVM
classification(light_clf.KernelSVC(
kernel="rbf", random_state=RANDOM_SEED)),
classification_binary(light_clf.KernelSVC(
kernel="rbf", random_state=RANDOM_SEED)),
classification_binary(light_clf.KernelSVC(
kernel="linear", random_state=RANDOM_SEED)),
classification_binary(light_clf.KernelSVC(
kernel="poly", degree=2, random_state=RANDOM_SEED)),
classification_binary(light_clf.KernelSVC(
kernel="cosine", random_state=RANDOM_SEED)),
classification_binary(light_clf.KernelSVC(
kernel="sigmoid", random_state=RANDOM_SEED)),
# Sklearn Linear Regression
regression(linear_model.ARDRegression()),
regression(linear_model.BayesianRidge()),
Expand Down Expand Up @@ -250,7 +271,18 @@ def classification_binary_random(model):
len(utils.get_regression_model_trainer().y_train))),
fit_regularized=STATSMODELS_LINEAR_REGULARIZED_PARAMS))),
# Linear Classifiers
# lightning Linear Regression
regression(light_reg.AdaGradRegressor(random_state=RANDOM_SEED)),
regression(light_reg.CDRegressor(random_state=RANDOM_SEED)),
regression(light_reg.FistaRegressor()),
regression(light_reg.SAGARegressor(random_state=RANDOM_SEED)),
regression(light_reg.SAGRegressor(random_state=RANDOM_SEED)),
regression(light_reg.SDCARegressor(random_state=RANDOM_SEED)),
# default loss results in nan coefs
regression(light_reg.SGDRegressor(
loss="huber", random_state=RANDOM_SEED)),
# Sklearn Linear Classifiers
classification(linear_model.LogisticRegression(
random_state=RANDOM_SEED)),
classification(linear_model.LogisticRegressionCV(
Expand All @@ -259,9 +291,11 @@ def classification_binary_random(model):
random_state=RANDOM_SEED)),
classification(linear_model.Perceptron(
random_state=RANDOM_SEED)),
classification(linear_model.RidgeClassifier(random_state=RANDOM_SEED)),
classification(linear_model.RidgeClassifier(
random_state=RANDOM_SEED)),
classification(linear_model.RidgeClassifierCV()),
classification(linear_model.SGDClassifier(random_state=RANDOM_SEED)),
classification(linear_model.SGDClassifier(
random_state=RANDOM_SEED)),
classification_binary(linear_model.LogisticRegression(
random_state=RANDOM_SEED)),
Expand All @@ -277,6 +311,39 @@ def classification_binary_random(model):
classification_binary(linear_model.SGDClassifier(
random_state=RANDOM_SEED)),
# lightning Linear Classifiers
classification(light_clf.AdaGradClassifier(
random_state=RANDOM_SEED)),
classification(light_clf.CDClassifier(
random_state=RANDOM_SEED)),
classification(light_clf.FistaClassifier()),
classification(light_clf.SDCAClassifier(
random_state=RANDOM_SEED)),
classification(light_clf.SAGClassifier(
random_state=RANDOM_SEED)),
classification(light_clf.SAGAClassifier(
random_state=RANDOM_SEED)),
classification(light_clf.SGDClassifier(
random_state=RANDOM_SEED)),
classification(light_clf.SVRGClassifier(
random_state=RANDOM_SEED)),
classification_binary(light_clf.AdaGradClassifier(
random_state=RANDOM_SEED)),
classification_binary(light_clf.CDClassifier(
random_state=RANDOM_SEED)),
classification_binary(light_clf.FistaClassifier()),
classification_binary(light_clf.SDCAClassifier(
random_state=RANDOM_SEED)),
classification_binary(light_clf.SAGClassifier(
random_state=RANDOM_SEED)),
classification_binary(light_clf.SAGAClassifier(
random_state=RANDOM_SEED)),
classification_binary(light_clf.SGDClassifier(
random_state=RANDOM_SEED)),
classification_binary(light_clf.SVRGClassifier(
random_state=RANDOM_SEED)),
# Decision trees
regression(tree.DecisionTreeRegressor(**TREE_PARAMS)),
regression(tree.ExtraTreeRegressor(**TREE_PARAMS)),
Expand Down

0 comments on commit 74932bd

Please sign in to comment.