Skip to content

Commit

Permalink
Add support for RANSACRegressor (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Feb 13, 2020
1 parent e51b65c commit d28cb00
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pip install m2cgen

| | Classification | Regression |
| --- | --- | --- |
| **Linear** | <ul><li>LogisticRegression</li><li>LogisticRegressionCV</li><li>PassiveAggressiveClassifier</li><li>Perceptron</li><li>RidgeClassifier</li><li>RidgeClassifierCV</li><li>SGDClassifier</li></ul> | <ul><li>ARDRegression</li><li>BayesianRidge</li><li>ElasticNet</li><li>ElasticNetCV</li><li>HuberRegressor</li><li>Lars</li><li>LarsCV</li><li>Lasso</li><li>LassoCV</li><li>LassoLars</li><li>LassoLarsCV</li><li>LassoLarsIC</li><li>LinearRegression</li><li>OrthogonalMatchingPursuit</li><li>OrthogonalMatchingPursuitCV</li><li>PassiveAggressiveRegressor</li><li>Ridge</li><li>RidgeCV</li><li>SGDRegressor</li><li>TheilSenRegressor</li></ul> |
| **Linear** | <ul><li>LogisticRegression</li><li>LogisticRegressionCV</li><li>PassiveAggressiveClassifier</li><li>Perceptron</li><li>RidgeClassifier</li><li>RidgeClassifierCV</li><li>SGDClassifier</li></ul> | <ul><li>ARDRegression</li><li>BayesianRidge</li><li>ElasticNet</li><li>ElasticNetCV</li><li>HuberRegressor</li><li>Lars</li><li>LarsCV</li><li>Lasso</li><li>LassoCV</li><li>LassoLars</li><li>LassoLarsCV</li><li>LassoLarsIC</li><li>LinearRegression</li><li>OrthogonalMatchingPursuit</li><li>OrthogonalMatchingPursuitCV</li><li>PassiveAggressiveRegressor</li><li>RANSACRegressor(only supported regression estimators can be used as a base estimator)</li><li>Ridge</li><li>RidgeCV</li><li>SGDRegressor</li><li>TheilSenRegressor</li></ul> |
| **SVM** | <ul><li>LinearSVC</li><li>NuSVC</li><li>SVC</li></ul> | <ul><li>LinearSVR</li><li>NuSVR</li><li>SVR</li></ul> |
| **Tree** | <ul><li>DecisionTreeClassifier</li><li>ExtraTreeClassifier</li></ul> | <ul><li>DecisionTreeRegressor</li><li>ExtraTreeRegressor</li></ul> |
| **Random Forest** | <ul><li>ExtraTreesClassifier</li><li>LGBMClassifier(rf booster only)</li><li>RandomForestClassifier</li><li>XGBRFClassifier(binary only, multiclass is not supported yet)</li></ul> | <ul><li>ExtraTreesRegressor</li><li>LGBMRegressor(rf booster only)</li><li>RandomForestRegressor</li><li>XGBRFRegressor</li></ul> |
Expand Down
3 changes: 3 additions & 0 deletions m2cgen/assemblers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
XGBoostLinearModelAssembler,
LightGBMModelAssembler)
from .svm import SVMModelAssembler
from .meta import RANSACModelAssembler

__all__ = [
LinearModelAssembler,
RANSACModelAssembler,
TreeModelAssembler,
RandomForestModelAssembler,
XGBoostModelAssemblerSelector,
Expand Down Expand Up @@ -55,6 +57,7 @@
"OrthogonalMatchingPursuit": LinearModelAssembler,
"OrthogonalMatchingPursuitCV": LinearModelAssembler,
"PassiveAggressiveRegressor": LinearModelAssembler,
"RANSACRegressor": RANSACModelAssembler,
"Ridge": LinearModelAssembler,
"RidgeCV": LinearModelAssembler,
"SGDRegressor": LinearModelAssembler,
Expand Down
19 changes: 19 additions & 0 deletions m2cgen/assemblers/meta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from m2cgen.assemblers.base import ModelAssembler


class BaseMetaAssembler(ModelAssembler):

def assemble(self):
# import here to avoid circular import error
from m2cgen.assemblers import get_assembler_cls
base_model = self._get_base_model()
return get_assembler_cls(base_model)(base_model).assemble()

def _get_base_model(self):
raise NotImplementedError


class RANSACModelAssembler(BaseMetaAssembler):

def _get_base_model(self):
return self.model.estimator_
36 changes: 36 additions & 0 deletions tests/assemblers/test_linear.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest
import numpy as np
from sklearn import linear_model
from sklearn.dummy import DummyRegressor
from sklearn.tree import DecisionTreeRegressor

from m2cgen import assemblers, ast
from tests import utils
Expand Down Expand Up @@ -127,3 +130,36 @@ def test_binary_class():
ast.BinNumOpType.ADD)

assert utils.cmp_exprs(actual, expected)


def test_ransac_custom_base_estimator():
base_estimator = DecisionTreeRegressor()
estimator = linear_model.RANSACRegressor(
base_estimator=base_estimator,
random_state=1)
estimator.fit([[1], [2], [3]], [1, 2, 3])

assembler = assemblers.RANSACModelAssembler(estimator)
actual = assembler.assemble()

expected = ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(2.5),
ast.CompOpType.LTE),
ast.NumVal(2.0),
ast.NumVal(3.0))

assert utils.cmp_exprs(actual, expected)


@pytest.mark.xfail(raises=NotImplementedError, strict=True)
def test_ransac_unknown_base_estimator():
base_estimator = DummyRegressor()
estimator = linear_model.RANSACRegressor(
base_estimator=base_estimator,
random_state=1)
estimator.fit([[1], [2], [3]], [1, 2, 3])

assembler = assemblers.RANSACModelAssembler(estimator)
assembler.assemble()
3 changes: 3 additions & 0 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ def classification_binary_random(model):
regression(linear_model.OrthogonalMatchingPursuitCV()),
regression(linear_model.PassiveAggressiveRegressor(
random_state=RANDOM_SEED)),
regression(linear_model.RANSACRegressor(
base_estimator=tree.ExtraTreeRegressor(**TREE_PARAMS),
random_state=RANDOM_SEED)),
regression(linear_model.Ridge(random_state=RANDOM_SEED)),
regression(linear_model.RidgeCV()),
regression(linear_model.SGDRegressor(random_state=RANDOM_SEED)),
Expand Down

0 comments on commit d28cb00

Please sign in to comment.