Skip to content

Commit

Permalink
Merge f8a597c into ba900a5
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Apr 24, 2020
2 parents ba900a5 + f8a597c commit 9c0bff6
Show file tree
Hide file tree
Showing 9 changed files with 567 additions and 44 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -43,7 +43,7 @@ pip install m2cgen

| | Classification | Regression |
| --- | --- | --- |
| **Linear** | <ul><li>scikit-learn<ul><li>LogisticRegression</li><li>LogisticRegressionCV</li><li>PassiveAggressiveClassifier</li><li>Perceptron</li><li>RidgeClassifier</li><li>RidgeClassifierCV</li><li>SGDClassifier</li></ul></li><li>lightning<ul><li>AdaGradClassifier</li><li>CDClassifier</li><li>FistaClassifier</li><li>SAGAClassifier</li><li>SAGClassifier</li><li>SDCAClassifier</li><li>SGDClassifier</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>ARDRegression</li><li>BayesianRidge</li><li>ElasticNet</li><li>ElasticNetCV</li><li>HuberRegressor</li><li>Lars</li><li>LarsCV</li><li>Lasso</li><li>LassoCV</li><li>LassoLars</li><li>LassoLarsCV</li><li>LassoLarsIC</li><li>LinearRegression</li><li>OrthogonalMatchingPursuit</li><li>OrthogonalMatchingPursuitCV</li><li>PassiveAggressiveRegressor</li><li>RANSACRegressor(only supported regression estimators can be used as a base estimator)</li><li>Ridge</li><li>RidgeCV</li><li>SGDRegressor</li><li>TheilSenRegressor</li></ul><li>StatsModels<ul><li>Generalized Least Squares (GLS)</li><li>Generalized Least Squares with AR Errors (GLSAR)</li><li>Ordinary Least Squares (OLS)</li><li>[Gaussian] Process Regression Using Maximum Likelihood-based Estimation (ProcessMLE)</li><li>Quantile Regression (QuantReg)</li><li>Weighted Least Squares (WLS)</li></ul><li>lightning<ul><li>AdaGradRegressor</li><li>CDRegressor</li><li>FistaRegressor</li><li>SAGARegressor</li><li>SAGRegressor</li><li>SDCARegressor</li></ul></li></ul> |
| **Linear** | <ul><li>scikit-learn<ul><li>LogisticRegression</li><li>LogisticRegressionCV</li><li>PassiveAggressiveClassifier</li><li>Perceptron</li><li>RidgeClassifier</li><li>RidgeClassifierCV</li><li>SGDClassifier</li></ul></li><li>lightning<ul><li>AdaGradClassifier</li><li>CDClassifier</li><li>FistaClassifier</li><li>SAGAClassifier</li><li>SAGClassifier</li><li>SDCAClassifier</li><li>SGDClassifier</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>ARDRegression</li><li>BayesianRidge</li><li>ElasticNet</li><li>ElasticNetCV</li><li>HuberRegressor</li><li>Lars</li><li>LarsCV</li><li>Lasso</li><li>LassoCV</li><li>LassoLars</li><li>LassoLarsCV</li><li>LassoLarsIC</li><li>LinearRegression</li><li>OrthogonalMatchingPursuit</li><li>OrthogonalMatchingPursuitCV</li><li>PassiveAggressiveRegressor</li><li>RANSACRegressor(only supported regression estimators can be used as a base estimator)</li><li>Ridge</li><li>RidgeCV</li><li>SGDRegressor</li><li>TheilSenRegressor</li></ul><li>StatsModels<ul><li>Generalized Least Squares (GLS)</li><li>Generalized Least Squares with AR Errors (GLSAR)</li><li>Generalized Linear Models (GLM)</li><li>Ordinary Least Squares (OLS)</li><li>[Gaussian] Process Regression Using Maximum Likelihood-based Estimation (ProcessMLE)</li><li>Quantile Regression (QuantReg)</li><li>Weighted Least Squares (WLS)</li></ul><li>lightning<ul><li>AdaGradRegressor</li><li>CDRegressor</li><li>FistaRegressor</li><li>SAGARegressor</li><li>SAGRegressor</li><li>SDCARegressor</li></ul></li></ul> |
| **SVM** | <ul><li>scikit-learn<ul><li>LinearSVC</li><li>NuSVC</li><li>SVC</li></ul></li><li>lightning<ul><li>KernelSVC</li><li>LinearSVC</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>LinearSVR</li><li>NuSVR</li><li>SVR</li></ul></li><li>lightning<ul><li>LinearSVR</li></ul></li></ul> |
| **Tree** | <ul><li>DecisionTreeClassifier</li><li>ExtraTreeClassifier</li></ul> | <ul><li>DecisionTreeRegressor</li><li>ExtraTreeRegressor</li></ul> |
| **Random Forest** | <ul><li>ExtraTreesClassifier</li><li>LGBMClassifier(rf booster only)</li><li>RandomForestClassifier</li><li>XGBRFClassifier(binary only, multiclass is not supported yet)</li></ul> | <ul><li>ExtraTreesRegressor</li><li>LGBMRegressor(rf booster only)</li><li>RandomForestRegressor</li><li>XGBRFRegressor</li></ul> |
Expand Down
11 changes: 8 additions & 3 deletions m2cgen/assemblers/__init__.py
@@ -1,6 +1,8 @@
from .linear import (SklearnLinearModelAssembler,
StatsmodelsLinearModelAssembler,
ProcessMLEModelAssembler)
ProcessMLEModelAssembler,
StatsmodelsGLMModelAssembler,
StatsmodelsModelAssemblerSelector)
from .tree import TreeModelAssembler
from .ensemble import RandomForestModelAssembler
from .boosting import (XGBoostModelAssemblerSelector,
Expand All @@ -23,6 +25,8 @@
LightGBMModelAssembler,
SklearnSVMModelAssembler,
LightningSVMModelAssembler,
StatsmodelsGLMModelAssembler,
StatsmodelsModelAssemblerSelector,
]


Expand Down Expand Up @@ -74,9 +78,10 @@
"sklearn_TheilSenRegressor": SklearnLinearModelAssembler,

# Statsmodels Linear Regressors
"statsmodels_GLMResultsWrapper": StatsmodelsGLMModelAssembler,
"statsmodels_ProcessMLEResults": ProcessMLEModelAssembler,
"statsmodels_RegressionResultsWrapper": StatsmodelsLinearModelAssembler,
"statsmodels_RegularizedResultsWrapper": StatsmodelsLinearModelAssembler,
"statsmodels_RegularizedResultsWrapper": StatsmodelsModelAssemblerSelector,

# Lightning Linear Regressors
"lightning_AdaGradRegressor": SklearnLinearModelAssembler,
Expand Down Expand Up @@ -130,6 +135,6 @@ def get_assembler_cls(model):

if not assembler_cls:
raise NotImplementedError(
"Model {} is not supported".format(model_name))
"Model '{}' is not supported".format(model_name))

return assembler_cls
114 changes: 112 additions & 2 deletions m2cgen/assemblers/linear.py
Expand Up @@ -15,13 +15,18 @@ def _build_ast(self):
intercept = utils.to_1d_array(self._get_intercept())

if coef.shape[0] == 1:
return _linear_to_ast(coef[0], intercept[0])
return self._final_transform(
_linear_to_ast(coef[0], intercept[0]))

exprs = []
for idx in range(coef.shape[0]):
exprs.append(_linear_to_ast(coef[idx], intercept[idx]))
exprs.append(self._final_transform(
_linear_to_ast(coef[idx], intercept[idx])))
return ast.VectorVal(exprs)

def _final_transform(self, ast_to_transform):
    """Hook applied to each per-class linear expression before it is returned.

    Base implementation is the identity transform; subclasses (e.g. the
    GLM assembler) override this to wrap the linear predictor in an
    inverse link function.
    """
    return ast_to_transform

def _get_intercept(self):
raise NotImplementedError

Expand Down Expand Up @@ -70,6 +75,111 @@ def _get_coef(self):
return self.model.params[:self.model.k_exog]


class StatsmodelsGLMModelAssembler(StatsmodelsLinearModelAssembler):
    """Assembler for statsmodels GLM results.

    Builds the same linear predictor as the plain linear assembler, then
    wraps it in the inverse of the model's link function so that the
    generated code returns the predicted mean rather than the raw
    linear combination.
    """

    # Lowercased link-class name -> handler method name.  Two statsmodels
    # spellings of the negative-binomial link map to the same handler.
    _LINK_HANDLERS = {
        "logit": "_logit",
        "power": "_power",
        "inverse_power": "_inverse_power",
        "sqrt": "_sqrt",
        "inverse_squared": "_inverse_squared",
        "identity": "_identity",
        "log": "_log",
        "cloglog": "_cloglog",
        "negativebinomial": "_negativebinomial",
        "nbinom": "_negativebinomial",
    }

    def _final_transform(self, ast_to_transform):
        """Apply the inverse link function to the linear predictor AST."""
        link_name = type(self.model.model.family.link).__name__
        handler_name = self._LINK_HANDLERS.get(link_name.lower())
        if handler_name is None:
            raise ValueError(
                "Unsupported link function '{}'".format(link_name))
        return getattr(self, handler_name)(ast_to_transform)

    @staticmethod
    def _negate(expr):
        # Emit 0.0 - expr rather than relying on a unary-minus AST node.
        return utils.sub(ast.NumVal(0.0), expr)

    def _logit(self, ast_to_transform):
        # Inverse of logit: sigmoid(x) = 1 / (1 + exp(-x)).
        return utils.div(
            ast.NumVal(1.0),
            utils.add(
                ast.NumVal(1.0),
                ast.ExpExpr(self._negate(ast_to_transform))))

    def _power(self, ast_to_transform):
        # Inverse of g(mu) = mu^p is x^(1/p); common exponents get
        # cheaper specialized forms.
        exponent = self.model.model.family.link.power
        if exponent == 1:
            return self._identity(ast_to_transform)
        if exponent == -1:
            return self._inverse_power(ast_to_transform)
        if exponent == 2:
            return ast.SqrtExpr(ast_to_transform)
        if exponent == -2:
            return self._inverse_squared(ast_to_transform)
        if exponent < 0:
            # Some target languages may not support negative exponents,
            # so emit 1 / x^(1/-p) instead of x^(1/p) with p < 0.
            return utils.div(
                ast.NumVal(1.0),
                ast.PowExpr(ast_to_transform, ast.NumVal(1 / -exponent)))
        return ast.PowExpr(ast_to_transform, ast.NumVal(1 / exponent))

    def _inverse_power(self, ast_to_transform):
        # Inverse of g(mu) = 1/mu is 1/x.
        return utils.div(ast.NumVal(1.0), ast_to_transform)

    def _sqrt(self, ast_to_transform):
        # Inverse of g(mu) = sqrt(mu) is x^2.
        return ast.PowExpr(ast_to_transform, ast.NumVal(2))

    def _inverse_squared(self, ast_to_transform):
        # Inverse of g(mu) = 1/mu^2 is 1/sqrt(x).
        return utils.div(ast.NumVal(1.0), ast.SqrtExpr(ast_to_transform))

    def _identity(self, ast_to_transform):
        # Identity link: the linear predictor is the mean.
        return ast_to_transform

    def _log(self, ast_to_transform):
        # Inverse of g(mu) = log(mu) is exp(x).
        return ast.ExpExpr(ast_to_transform)

    def _cloglog(self, ast_to_transform):
        # Inverse of complementary log-log: 1 - exp(-exp(x)).
        return utils.sub(
            ast.NumVal(1.0),
            ast.ExpExpr(self._negate(ast.ExpExpr(ast_to_transform))))

    def _negativebinomial(self, ast_to_transform):
        # Inverse of the negative-binomial link:
        # -1 / (alpha * (1 - exp(-x))).
        return utils.div(
            ast.NumVal(-1.0),
            utils.mul(
                ast.NumVal(self.model.model.family.link.alpha),
                utils.sub(
                    ast.NumVal(1.0),
                    ast.ExpExpr(self._negate(ast_to_transform)))))


class StatsmodelsModelAssemblerSelector(ModelAssembler):
    """Dispatches a statsmodels results wrapper to a concrete assembler.

    Some statsmodels wrappers (e.g. RegularizedResultsWrapper) can hold
    results of several underlying model classes; this selector inspects
    the wrapped model's class name and delegates to the matching
    assembler.
    """

    # Underlying model classes handled by the plain linear assembler.
    _LINEAR_MODELS = frozenset({"GLS", "GLSAR", "OLS", "WLS"})

    def __init__(self, model):
        model_name = type(model.model).__name__
        if model_name == "GLM":
            self.assembler = StatsmodelsGLMModelAssembler(model)
        elif model_name in self._LINEAR_MODELS:
            self.assembler = StatsmodelsLinearModelAssembler(model)
        else:
            raise NotImplementedError(
                "Model '{}' is not supported".format(model_name))

    def assemble(self):
        """Delegate AST assembly to the selected concrete assembler."""
        return self.assembler.assemble()


def _linear_to_ast(coef, intercept):
feature_weight_mul_ops = []

Expand Down
3 changes: 2 additions & 1 deletion m2cgen/assemblers/svm.py
Expand Up @@ -13,7 +13,8 @@ def __init__(self, model):
kernel_type = model.kernel
supported_kernels = self._get_supported_kernels()
if kernel_type not in supported_kernels:
raise ValueError("Unsupported kernel type {}".format(kernel_type))
raise ValueError(
"Unsupported kernel type '{}'".format(kernel_type))
self._kernel_fun = supported_kernels[kernel_type]

gamma = self._get_gamma()
Expand Down
4 changes: 2 additions & 2 deletions m2cgen/interpreters/mixins.py
Expand Up @@ -61,7 +61,7 @@ class LinearAlgebraMixin(BaseToCodeInterpreter):
def interpret_bin_vector_expr(self, expr, extra_func_args=(), **kwargs):
if expr.op not in self.supported_bin_vector_ops:
raise NotImplementedError(
"Op {} is unsupported".format(expr.op.name))
"Op '{}' is unsupported".format(expr.op.name))

self.with_linear_algebra = True

Expand All @@ -77,7 +77,7 @@ def interpret_bin_vector_num_expr(self, expr, extra_func_args=(),
**kwargs):
if expr.op not in self.supported_bin_vector_num_ops:
raise NotImplementedError(
"Op {} is unsupported".format(expr.op.name))
"Op '{}' is unsupported".format(expr.op.name))

self.with_linear_algebra = True

Expand Down

0 comments on commit 9c0bff6

Please sign in to comment.