Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Mar 14, 2020
1 parent a8d0def commit 49bba30
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pip install m2cgen
| | Classification | Regression |
| --- | --- | --- |
| **Linear** | <ul><li>scikit-learn<ul><li>LogisticRegression</li><li>LogisticRegressionCV</li><li>PassiveAggressiveClassifier</li><li>Perceptron</li><li>RidgeClassifier</li><li>RidgeClassifierCV</li><li>SGDClassifier</li></ul></li><li>lightning<ul><li>AdaGradClassifier</li><li>CDClassifier</li><li>FistaClassifier</li><li>SAGAClassifier</li><li>SAGClassifier</li><li>SDCAClassifier</li><li>SGDClassifier</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>ARDRegression</li><li>BayesianRidge</li><li>ElasticNet</li><li>ElasticNetCV</li><li>HuberRegressor</li><li>Lars</li><li>LarsCV</li><li>Lasso</li><li>LassoCV</li><li>LassoLars</li><li>LassoLarsCV</li><li>LassoLarsIC</li><li>LinearRegression</li><li>OrthogonalMatchingPursuit</li><li>OrthogonalMatchingPursuitCV</li><li>PassiveAggressiveRegressor</li><li>RANSACRegressor(only supported regression estimators can be used as a base estimator)</li><li>Ridge</li><li>RidgeCV</li><li>SGDRegressor</li><li>TheilSenRegressor</li></ul><li>StatsModels<ul><li>Generalized Least Squares (GLS)</li><li>Generalized Least Squares with AR Errors (GLSAR)</li><li>Ordinary Least Squares (OLS)</li><li>Quantile Regression (QuantReg)</li><li>Weighted Least Squares (WLS)</li></ul><li>lightning<ul><li>AdaGradRegressor</li><li>CDRegressor</li><li>FistaRegressor</li><li>SAGARegressor</li><li>SAGRegressor</li><li>SDCARegressor</li></ul></li></ul> |
| **SVM** | <ul><li>scikit-learn<ul><li>LinearSVC</li><li>NuSVC</li><li>SVC</li></ul></li><li>lightning<ul><li>KernelSVC</li><li>LinearSVC</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>LinearSVR</li><li>NuSVR</li><li>SVR</li></ul></li><li>lightning<ul><li>LinearSVR</li></ul></li></ul> |
| **SVM** | <ul><li>scikit-learn<ul><li>LinearSVC</li><li>NuSVC</li><li>SVC</li></ul></li><li>lightning<ul><li>KernelSVC (binary only, multiclass is not supported yet)</li><li>LinearSVC</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>LinearSVR</li><li>NuSVR</li><li>SVR</li></ul></li><li>lightning<ul><li>LinearSVR</li></ul></li></ul> |
| **Tree** | <ul><li>DecisionTreeClassifier</li><li>ExtraTreeClassifier</li></ul> | <ul><li>DecisionTreeRegressor</li><li>ExtraTreeRegressor</li></ul> |
| **Random Forest** | <ul><li>ExtraTreesClassifier</li><li>LGBMClassifier(rf booster only)</li><li>RandomForestClassifier</li><li>XGBRFClassifier(binary only, multiclass is not supported yet)</li></ul> | <ul><li>ExtraTreesRegressor</li><li>LGBMRegressor(rf booster only)</li><li>RandomForestRegressor</li><li>XGBRFRegressor</li></ul> |
| **Boosting** | <ul><li>LGBMClassifier(gbdt/dart/goss booster only)</li><li>XGBClassifier(gbtree/gblinear booster only)</li><ul> | <ul><li>LGBMRegressor(gbdt/dart/goss booster only)</li><li>XGBRegressor(gbtree/gblinear booster only)</li></ul> |
Expand Down
15 changes: 6 additions & 9 deletions m2cgen/assemblers/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ def _apply_kernel(self, support_vectors, to_reuse=False):
return kernel_exprs

def _get_supported_kernels(self):
return {}
return {
"rbf": self._rbf_kernel,
"sigmoid": self._sigmoid_kernel,
"poly": self._poly_kernel,
"linear": self._linear_kernel
}

def _get_gamma(self):
raise NotImplementedError
Expand All @@ -74,14 +79,6 @@ def _get_single_intercept(self):

class SklearnSVMModelAssembler(BaseSVMModelAssembler):

def _get_supported_kernels(self):
return {
"rbf": self._rbf_kernel,
"sigmoid": self._sigmoid_kernel,
"poly": self._poly_kernel,
"linear": self._linear_kernel
}

def _get_gamma(self):
return self.model._gamma

Expand Down
3 changes: 2 additions & 1 deletion tests/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_count_all_exprs_types():
ast.BinVectorExpr(
ast.VectorVal([
ast.ExpExpr(ast.NumVal(2)),
ast.SqrtExpr(ast.NumVal(2)),
ast.PowExpr(ast.NumVal(2), ast.NumVal(3)),
ast.TanhExpr(ast.NumVal(1)),
ast.BinNumExpr(
Expand All @@ -67,4 +68,4 @@ def test_count_all_exprs_types():
)),
ast.BinNumOpType.MUL)

assert ast.count_exprs(expr) == 24
assert ast.count_exprs(expr) == 25

0 comments on commit 49bba30

Please sign in to comment.