Skip to content

Commit

Permalink
Add support for Perceptron and tests for PassiveAggressiveClassifier (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Feb 10, 2020
1 parent 81c3e6a commit f52dc13
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -41,7 +41,7 @@ pip install m2cgen

| | Classification | Regression |
| --- | --- | --- |
| **Linear** | <ul><li>LogisticRegression</li><li>LogisticRegressionCV</li><li>PassiveAggressiveClassifier</li><li>RidgeClassifier</li><li>RidgeClassifierCV</li><li>SGDClassifier</li></ul> | <ul><li>ARDRegression</li><li>BayesianRidge</li><li>ElasticNet</li><li>ElasticNetCV</li><li>HuberRegressor</li><li>Lars</li><li>LarsCV</li><li>Lasso</li><li>LassoCV</li><li>LassoLars</li><li>LassoLarsCV</li><li>LassoLarsIC</li><li>LinearRegression</li><li>OrthogonalMatchingPursuit</li><li>OrthogonalMatchingPursuitCV</li><li>PassiveAggressiveRegressor</li><li>Ridge</li><li>RidgeCV</li><li>SGDRegressor</li><li>TheilSenRegressor</li></ul> |
| **Linear** | <ul><li>LogisticRegression</li><li>LogisticRegressionCV</li><li>PassiveAggressiveClassifier</li><li>Perceptron</li><li>RidgeClassifier</li><li>RidgeClassifierCV</li><li>SGDClassifier</li></ul> | <ul><li>ARDRegression</li><li>BayesianRidge</li><li>ElasticNet</li><li>ElasticNetCV</li><li>HuberRegressor</li><li>Lars</li><li>LarsCV</li><li>Lasso</li><li>LassoCV</li><li>LassoLars</li><li>LassoLarsCV</li><li>LassoLarsIC</li><li>LinearRegression</li><li>OrthogonalMatchingPursuit</li><li>OrthogonalMatchingPursuitCV</li><li>PassiveAggressiveRegressor</li><li>Ridge</li><li>RidgeCV</li><li>SGDRegressor</li><li>TheilSenRegressor</li></ul> |
| **SVM** | <ul><li>LinearSVC</li><li>NuSVC</li><li>SVC</li></ul> | <ul><li>LinearSVR</li><li>NuSVR</li><li>SVR</li></ul> |
| **Tree** | <ul><li>DecisionTreeClassifier</li><li>ExtraTreeClassifier</li></ul> | <ul><li>DecisionTreeRegressor</li><li>ExtraTreeRegressor</li></ul> |
| **Random Forest** | <ul><li>ExtraTreesClassifier</li><li>LGBMClassifier(rf booster only)</li><li>RandomForestClassifier</li><li>XGBRFClassifier(binary only, multiclass is not supported yet)</li></ul> | <ul><li>ExtraTreesRegressor</li><li>LGBMRegressor(rf booster only)</li><li>RandomForestRegressor</li><li>XGBRFRegressor</li></ul> |
Expand Down
3 changes: 2 additions & 1 deletion m2cgen/assemblers/__init__.py
Expand Up @@ -60,10 +60,11 @@
"SGDRegressor": LinearModelAssembler,
"TheilSenRegressor": LinearModelAssembler,

# Logistic Regressors
# Linear Classifiers
"LogisticRegression": LinearModelAssembler,
"LogisticRegressionCV": LinearModelAssembler,
"PassiveAggressiveClassifier": LinearModelAssembler,
"Perceptron": LinearModelAssembler,
"RidgeClassifier": LinearModelAssembler,
"RidgeClassifierCV": LinearModelAssembler,
"SGDClassifier": LinearModelAssembler,
Expand Down
10 changes: 9 additions & 1 deletion tests/e2e/test_e2e.py
Expand Up @@ -218,11 +218,15 @@ def classification_binary_random(model):
regression(linear_model.SGDRegressor(random_state=RANDOM_SEED)),
regression(linear_model.TheilSenRegressor(random_state=RANDOM_SEED)),
# Logistic Regression
# Linear Classifiers
classification(linear_model.LogisticRegression(
random_state=RANDOM_SEED)),
classification(linear_model.LogisticRegressionCV(
random_state=RANDOM_SEED)),
classification(linear_model.PassiveAggressiveClassifier(
random_state=RANDOM_SEED)),
classification(linear_model.Perceptron(
random_state=RANDOM_SEED)),
classification(linear_model.RidgeClassifier(random_state=RANDOM_SEED)),
classification(linear_model.RidgeClassifierCV()),
classification(linear_model.SGDClassifier(random_state=RANDOM_SEED)),
Expand All @@ -231,6 +235,10 @@ def classification_binary_random(model):
random_state=RANDOM_SEED)),
classification_binary(linear_model.LogisticRegressionCV(
random_state=RANDOM_SEED)),
classification_binary(linear_model.PassiveAggressiveClassifier(
random_state=RANDOM_SEED)),
classification_binary(linear_model.Perceptron(
random_state=RANDOM_SEED)),
classification_binary(linear_model.RidgeClassifier(
random_state=RANDOM_SEED)),
classification_binary(linear_model.RidgeClassifierCV()),
Expand Down

0 comments on commit f52dc13

Please sign in to comment.