Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added lightning.KernelSVC (binary only) #176

Merged
merged 6 commits into from Mar 16, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -43,7 +43,7 @@ pip install m2cgen
| | Classification | Regression |
| --- | --- | --- |
| **Linear** | <ul><li>scikit-learn<ul><li>LogisticRegression</li><li>LogisticRegressionCV</li><li>PassiveAggressiveClassifier</li><li>Perceptron</li><li>RidgeClassifier</li><li>RidgeClassifierCV</li><li>SGDClassifier</li></ul></li><li>lightning<ul><li>AdaGradClassifier</li><li>CDClassifier</li><li>FistaClassifier</li><li>SAGAClassifier</li><li>SAGClassifier</li><li>SDCAClassifier</li><li>SGDClassifier</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>ARDRegression</li><li>BayesianRidge</li><li>ElasticNet</li><li>ElasticNetCV</li><li>HuberRegressor</li><li>Lars</li><li>LarsCV</li><li>Lasso</li><li>LassoCV</li><li>LassoLars</li><li>LassoLarsCV</li><li>LassoLarsIC</li><li>LinearRegression</li><li>OrthogonalMatchingPursuit</li><li>OrthogonalMatchingPursuitCV</li><li>PassiveAggressiveRegressor</li><li>RANSACRegressor(only supported regression estimators can be used as a base estimator)</li><li>Ridge</li><li>RidgeCV</li><li>SGDRegressor</li><li>TheilSenRegressor</li></ul><li>StatsModels<ul><li>Generalized Least Squares (GLS)</li><li>Generalized Least Squares with AR Errors (GLSAR)</li><li>Ordinary Least Squares (OLS)</li><li>Quantile Regression (QuantReg)</li><li>Weighted Least Squares (WLS)</li></ul><li>lightning<ul><li>AdaGradRegressor</li><li>CDRegressor</li><li>FistaRegressor</li><li>SAGARegressor</li><li>SAGRegressor</li><li>SDCARegressor</li></ul></li></ul> |
| **SVM** | <ul><li>scikit-learn<ul><li>LinearSVC</li><li>NuSVC</li><li>SVC</li></ul></li><li>lightning<ul><li>LinearSVC</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>LinearSVR</li><li>NuSVR</li><li>SVR</li></ul></li><li>lightning<ul><li>LinearSVR</li></ul></li></ul> |
| **SVM** | <ul><li>scikit-learn<ul><li>LinearSVC</li><li>NuSVC</li><li>SVC</li></ul></li><li>lightning<ul><li>KernelSVC</li><li>LinearSVC</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>LinearSVR</li><li>NuSVR</li><li>SVR</li></ul></li><li>lightning<ul><li>LinearSVR</li></ul></li></ul> |
StrikerRUS marked this conversation as resolved.
Show resolved Hide resolved
| **Tree** | <ul><li>DecisionTreeClassifier</li><li>ExtraTreeClassifier</li></ul> | <ul><li>DecisionTreeRegressor</li><li>ExtraTreeRegressor</li></ul> |
| **Random Forest** | <ul><li>ExtraTreesClassifier</li><li>LGBMClassifier(rf booster only)</li><li>RandomForestClassifier</li><li>XGBRFClassifier(binary only, multiclass is not supported yet)</li></ul> | <ul><li>ExtraTreesRegressor</li><li>LGBMRegressor(rf booster only)</li><li>RandomForestRegressor</li><li>XGBRFRegressor</li></ul> |
| **Boosting** | <ul><li>LGBMClassifier(gbdt/dart/goss booster only)</li><li>XGBClassifier(gbtree/gblinear booster only)</li><ul> | <ul><li>LGBMRegressor(gbdt/dart/goss booster only)</li><li>XGBRegressor(gbtree/gblinear booster only)</li></ul> |
Expand Down
14 changes: 8 additions & 6 deletions m2cgen/assemblers/__init__.py
Expand Up @@ -6,7 +6,7 @@
XGBoostTreeModelAssembler,
XGBoostLinearModelAssembler,
LightGBMModelAssembler)
from .svm import SVMModelAssembler
from .svm import SklearnSVMModelAssembler, LightningSVMModelAssembler
from .meta import RANSACModelAssembler

__all__ = [
Expand All @@ -19,7 +19,8 @@
XGBoostTreeModelAssembler,
XGBoostLinearModelAssembler,
LightGBMModelAssembler,
SVMModelAssembler,
SklearnSVMModelAssembler,
LightningSVMModelAssembler,
]


Expand All @@ -37,12 +38,13 @@
# Sklearn SVM
"sklearn_LinearSVC": SklearnLinearModelAssembler,
"sklearn_LinearSVR": SklearnLinearModelAssembler,
"sklearn_NuSVC": SVMModelAssembler,
"sklearn_NuSVR": SVMModelAssembler,
"sklearn_SVC": SVMModelAssembler,
"sklearn_SVR": SVMModelAssembler,
"sklearn_NuSVC": SklearnSVMModelAssembler,
"sklearn_NuSVR": SklearnSVMModelAssembler,
"sklearn_SVC": SklearnSVMModelAssembler,
"sklearn_SVR": SklearnSVMModelAssembler,

# Lightning SVM
"lightning_KernelSVC": LightningSVMModelAssembler,
"lightning_LinearSVC": SklearnLinearModelAssembler,
"lightning_LinearSVR": SklearnLinearModelAssembler,

Expand Down
1 change: 1 addition & 0 deletions m2cgen/assemblers/boosting.py
@@ -1,5 +1,6 @@
import json
import numpy as np

from m2cgen import ast
from m2cgen.assemblers import utils
from m2cgen.assemblers.base import ModelAssembler
Expand Down
118 changes: 93 additions & 25 deletions m2cgen/assemblers/svm.py
@@ -1,38 +1,27 @@
import numpy as np

from m2cgen import ast
from m2cgen.assemblers import utils
from m2cgen.assemblers.base import ModelAssembler


class SVMModelAssembler(ModelAssembler):
class BaseSVMModelAssembler(ModelAssembler):

def __init__(self, model):
super().__init__(model)

supported_kernels = {
"rbf": self._rbf_kernel,
"sigmoid": self._sigmoid_kernel,
"poly": self._poly_kernel,
"linear": self._linear_kernel
}
kernel_type = model.kernel
supported_kernels = self._get_supported_kernels()
if kernel_type not in supported_kernels:
raise ValueError("Unsupported kernel type {}".format(kernel_type))
self._kernel_fun = supported_kernels[kernel_type]

n_features = len(model.support_vectors_[0])

gamma = model.gamma
if gamma == "auto" or gamma == "auto_deprecated":
gamma = 1.0 / n_features
gamma = self._get_gamma()
self._gamma_expr = ast.NumVal(gamma)
self._neg_gamma_expr = utils.sub(ast.NumVal(0), ast.NumVal(gamma),
to_reuse=True)

self._output_size = 1
if type(model).__name__ in ("SVC", "NuSVC"):
n_classes = len(model.n_support_)
if n_classes > 2:
self._output_size = n_classes
self._output_size = self._get_output_size()

def assemble(self):
if self._output_size > 1:
Expand All @@ -42,8 +31,8 @@ def assemble(self):

def _assemble_single_output(self):
support_vectors = self.model.support_vectors_
coef = self.model.dual_coef_[0]
intercept = self.model.intercept_[0]
coef = self._get_single_coef()
intercept = self._get_single_intercept()

kernel_exprs = self._apply_kernel(support_vectors)

Expand All @@ -57,6 +46,53 @@ def _assemble_single_output(self):
ast.NumVal(intercept),
*kernel_weight_mul_ops)

def _apply_kernel(self, support_vectors, to_reuse=False):
    """Apply the configured kernel to every support vector.

    Each kernel evaluation is wrapped in a SubroutineExpr so interpreters
    may emit it as a separate subroutine.
    """
    return [
        ast.SubroutineExpr(self._kernel_fun(vector), to_reuse=to_reuse)
        for vector in support_vectors
    ]

def _get_supported_kernels(self):
    # Mapping of kernel name -> method that builds the kernel AST.
    # NOTE(review): the sibling hooks below raise NotImplementedError,
    # but this one returns {} — the constructor then rejects every kernel
    # with "Unsupported kernel type"; confirm this asymmetry is intended.
    return {}

def _get_gamma(self):
    # Subclasses return the numeric gamma used in the kernel expressions.
    raise NotImplementedError

def _get_output_size(self):
    # Subclasses return 1 for a scalar output, or the number of classes
    # for a multiclass decision function.
    raise NotImplementedError

def _assemble_multi_class_output(self):
    # Subclasses build the AST for the multiclass decision output.
    raise NotImplementedError

def _get_single_coef(self):
    # Subclasses return the coefficient vector for the binary/scalar case.
    raise NotImplementedError

def _get_single_intercept(self):
    # Subclasses return the intercept for the binary/scalar case.
    raise NotImplementedError


class SklearnSVMModelAssembler(BaseSVMModelAssembler):

def _get_supported_kernels(self):
return {
"rbf": self._rbf_kernel,
StrikerRUS marked this conversation as resolved.
Show resolved Hide resolved
"sigmoid": self._sigmoid_kernel,
"poly": self._poly_kernel,
"linear": self._linear_kernel
}

def _get_gamma(self):
    # Reads sklearn's private attribute: presumably the numeric gamma
    # already resolved from "auto"/"scale" — relies on sklearn internals,
    # verify against the supported sklearn versions.
    return self.model._gamma

def _get_output_size(self):
output_size = 1
if type(self.model).__name__ in {"SVC", "NuSVC"}:
n_classes = len(self.model.n_support_)
if n_classes > 2:
output_size = n_classes
return output_size

def _assemble_multi_class_output(self):
support_vectors = self.model.support_vectors_
coef = self.model.dual_coef_
Expand Down Expand Up @@ -96,12 +132,11 @@ def _assemble_multi_class_output(self):

return ast.VectorVal(decisions)

def _apply_kernel(self, support_vectors, to_reuse=False):
kernel_exprs = []
for v in support_vectors:
kernel = self._kernel_fun(v)
kernel_exprs.append(ast.SubroutineExpr(kernel, to_reuse=to_reuse))
return kernel_exprs
def _get_single_coef(self):
    # Binary case: the single row of dual coefficients.
    return self.model.dual_coef_[0]

def _get_single_intercept(self):
    # Binary case: the scalar intercept.
    return self.model.intercept_[0]

def _rbf_kernel(self, support_vector):
elem_wise = [
Expand Down Expand Up @@ -135,3 +170,36 @@ def _linear_kernel_with_gama_and_coef(self, support_vector):
kernel = self._linear_kernel(support_vector)
kernel = utils.mul(self._gamma_expr, kernel)
return utils.add(kernel, ast.NumVal(self.model.coef0))


class LightningSVMModelAssembler(SklearnSVMModelAssembler):
    """Assembler for SVM models from the `lightning` package.

    Reuses the sklearn kernel machinery but adds the "cosine" kernel and
    reads model attributes where lightning's layout differs from sklearn's.
    """

    def _get_supported_kernels(self):
        # lightning supports every sklearn kernel plus "cosine".
        supported = super()._get_supported_kernels()
        supported["cosine"] = self._cosine_kernel
        return supported

    def _get_gamma(self):
        # lightning exposes gamma as a plain public attribute.
        return self.model.gamma

    def _get_output_size(self):
        # Binary classification only — always a single output.
        return 1

    def _assemble_multi_class_output(self):
        raise NotImplementedError

    def _get_single_coef(self):
        return self.model.coef_[0]

    def _cosine_kernel(self, support_vector):
        # Normalize the support vector at assembly time (it is a constant),
        # and emit AST that normalizes the feature vector at scoring time.
        sv_norm = np.linalg.norm(support_vector)
        if sv_norm == 0.0:
            sv_norm = 1.0
        squared_features = (
            utils.mul(ast.FeatureRef(i), ast.FeatureRef(i))
            for i in range(len(support_vector)))
        feature_norm = ast.SqrtExpr(
            utils.apply_op_to_expressions(
                ast.BinNumOpType.ADD, *squared_features))
        dot_product = self._linear_kernel(support_vector / sv_norm)
        return utils.div(dot_product, feature_norm)
5 changes: 5 additions & 0 deletions m2cgen/assemblers/utils.py
@@ -1,11 +1,16 @@
import numpy as np

from m2cgen import ast


def mul(l, r, to_reuse=False):
    # Build the AST node for l * r.
    return ast.BinNumExpr(l, r, ast.BinNumOpType.MUL, to_reuse=to_reuse)


def div(l, r, to_reuse=False):
    # Build the AST node for l / r.
    return ast.BinNumExpr(l, r, ast.BinNumOpType.DIV, to_reuse=to_reuse)


def add(l, r, to_reuse=False):
    # Build the AST node for l + r.
    return ast.BinNumExpr(l, r, ast.BinNumOpType.ADD, to_reuse=to_reuse)

Expand Down
14 changes: 13 additions & 1 deletion m2cgen/ast.py
Expand Up @@ -49,6 +49,18 @@ def __str__(self):
return "ExpExpr(" + args + ")"


class SqrtExpr(NumExpr):
    """AST node for the square root of a scalar numeric expression."""

    def __init__(self, expr, to_reuse=False):
        # Sqrt is defined for scalar subexpressions only.
        assert expr.output_size == 1, "Only scalars are supported"

        self.expr = expr
        self.to_reuse = to_reuse

    def __str__(self):
        return "SqrtExpr({},to_reuse={})".format(self.expr, self.to_reuse)


class TanhExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
assert expr.output_size == 1, "Only scalars are supported"
Expand Down Expand Up @@ -237,7 +249,7 @@ def __str__(self):
(PowExpr, lambda e: [e.base_expr, e.exp_expr]),
(VectorVal, lambda e: e.exprs),
(IfExpr, lambda e: [e.test, e.body, e.orelse]),
((ExpExpr, TanhExpr, TransparentExpr), lambda e: [e.expr]),
((ExpExpr, SqrtExpr, TanhExpr, TransparentExpr), lambda e: [e.expr]),
StrikerRUS marked this conversation as resolved.
Show resolved Hide resolved
]


Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/c/interpreter.py
Expand Up @@ -19,6 +19,7 @@ class CInterpreter(ToCodeInterpreter,

exponent_function_name = "exp"
power_function_name = "pow"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"

def __init__(self, indent=4, function_name="score", *args, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/c_sharp/interpreter.py
Expand Up @@ -19,6 +19,7 @@ class CSharpInterpreter(ToCodeInterpreter, mixins.LinearAlgebraMixin):

exponent_function_name = "Exp"
power_function_name = "Pow"
sqrt_function_name = "Sqrt"
tanh_function_name = "Tanh"

def __init__(self, namespace="ML", class_name="Model", indent=4,
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/dart/interpreter.py
Expand Up @@ -23,6 +23,7 @@ class DartInterpreter(ToCodeInterpreter,

exponent_function_name = "exp"
power_function_name = "pow"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"

with_tanh_expr = False
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/go/interpreter.py
Expand Up @@ -18,6 +18,7 @@ class GoInterpreter(ToCodeInterpreter,

exponent_function_name = "math.Exp"
power_function_name = "math.Pow"
sqrt_function_name = "math.Sqrt"
tanh_function_name = "math.Tanh"

def __init__(self, indent=4, function_name="score", *args, **kwargs):
Expand Down
8 changes: 8 additions & 0 deletions m2cgen/interpreters/interpreter.py
Expand Up @@ -98,6 +98,7 @@ class ToCodeInterpreter(BaseToCodeInterpreter):

exponent_function_name = NotImplemented
power_function_name = NotImplemented
sqrt_function_name = NotImplemented
tanh_function_name = NotImplemented

def __init__(self, cg, feature_array_name="input"):
Expand Down Expand Up @@ -160,6 +161,13 @@ def interpret_exp_expr(self, expr, **kwargs):
return self._cg.function_invocation(
self.exponent_function_name, nested_result)

def interpret_sqrt_expr(self, expr, **kwargs):
    """Interpret a SqrtExpr by invoking the target language's sqrt function.

    The class-level default for `sqrt_function_name` is NotImplemented
    (see the attribute declarations above), which is truthy — a plain
    truthiness assert could never fire. Compare identity so a subclass
    that forgot to provide the name fails loudly here.
    """
    assert self.sqrt_function_name is not NotImplemented, \
        "Sqrt function is not provided"
    self.with_math_module = True
    nested_result = self._do_interpret(expr.expr, **kwargs)
    return self._cg.function_invocation(
        self.sqrt_function_name, nested_result)

def interpret_tanh_expr(self, expr, **kwargs):
assert self.tanh_function_name, "Tanh function is not provided"
self.with_math_module = True
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/java/interpreter.py
Expand Up @@ -28,6 +28,7 @@ class JavaInterpreter(ToCodeInterpreter,

exponent_function_name = "Math.exp"
power_function_name = "Math.pow"
sqrt_function_name = "Math.sqrt"
tanh_function_name = "Math.tanh"

def __init__(self, package_name=None, class_name="Model", indent=4,
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/javascript/interpreter.py
Expand Up @@ -21,6 +21,7 @@ class JavascriptInterpreter(ToCodeInterpreter,

exponent_function_name = "Math.exp"
power_function_name = "Math.pow"
sqrt_function_name = "Math.sqrt"
tanh_function_name = "Math.tanh"

def __init__(self, indent=4, function_name="score",
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/php/interpreter.py
Expand Up @@ -18,6 +18,7 @@ class PhpInterpreter(ToCodeInterpreter, mixins.LinearAlgebraMixin):

exponent_function_name = "exp"
power_function_name = "pow"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"

def __init__(self, indent=4, function_name="score", *args, **kwargs):
Expand Down
6 changes: 6 additions & 0 deletions m2cgen/interpreters/powershell/interpreter.py
Expand Up @@ -20,6 +20,7 @@ class PowershellInterpreter(ToCodeInterpreter,

exponent_function_name = "[math]::Exp"
power_function_name = "[math]::Pow"
sqrt_function_name = "[math]::Sqrt"
tanh_function_name = "[math]::Tanh"

def __init__(self, indent=4, function_name="Score", *args, **kwargs):
Expand Down Expand Up @@ -52,6 +53,11 @@ def interpret_exp_expr(self, expr, **kwargs):
return self._cg.math_function_invocation(
self.exponent_function_name, nested_result)

def interpret_sqrt_expr(self, expr, **kwargs):
    """Emit a [math]::Sqrt invocation around the interpreted inner expression."""
    inner = self._do_interpret(expr.expr, **kwargs)
    return self._cg.math_function_invocation(self.sqrt_function_name, inner)

def interpret_tanh_expr(self, expr, **kwargs):
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.math_function_invocation(
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/python/interpreter.py
Expand Up @@ -12,6 +12,7 @@ class PythonInterpreter(ToCodeInterpreter,

exponent_function_name = "math.exp"
power_function_name = "math.pow"
sqrt_function_name = "math.sqrt"
tanh_function_name = "math.tanh"

def __init__(self, indent=4, function_name="score", *args, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/r/interpreter.py
Expand Up @@ -20,6 +20,7 @@ class RInterpreter(ToCodeInterpreter,
bin_depth_threshold = 25

exponent_function_name = "exp"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"

def __init__(self, indent=4, function_name="score", *args, **kwargs):
Expand Down
7 changes: 7 additions & 0 deletions m2cgen/interpreters/visual_basic/interpreter.py
Expand Up @@ -67,6 +67,13 @@ def interpret_pow_expr(self, expr, **kwargs):
return self._cg.infix_expression(
left=base_result, right=exp_result, op="^")

def interpret_sqrt_expr(self, expr, **kwargs):
    """Lower sqrt(x) to the equivalent power expression x ^ 0.5.

    This interpreter has no dedicated sqrt function name, so it reuses the
    existing PowExpr handling instead.
    """
    as_pow = ast.PowExpr(
        base_expr=expr.expr,
        exp_expr=ast.NumVal(0.5),
        to_reuse=expr.to_reuse)
    return self.interpret_pow_expr(as_pow, **kwargs)

def interpret_tanh_expr(self, expr, **kwargs):
self.with_tanh_expr = True
return super(
Expand Down