Merge branch 'master' into github_actions
StrikerRUS committed Jul 7, 2020
2 parents 5ac6d80 + 2be5875 commit 9505091
Showing 7 changed files with 42 additions and 25 deletions.
2 changes: 1 addition & 1 deletion m2cgen/interpreters/c/code_generator.py
@@ -52,7 +52,7 @@ def add_assign_array_statement(self, source_var, target_var, size):
f"{size} * sizeof(double));")

def add_dependency(self, dep):
- super().prepend_code_line(f"#include {dep}")
+ self.prepend_code_line(f"#include {dep}")

def vector_init(self, values):
return f"(double[]){{{', '.join(values)}}}"
2 changes: 1 addition & 1 deletion m2cgen/interpreters/go/code_generator.py
@@ -38,7 +38,7 @@ def _get_var_declare_type(self, is_vector):
return self.vector_type if is_vector else self.scalar_type

def add_dependency(self, dep):
- super().prepend_code_line(f'import "{dep}"')
+ self.prepend_code_line(f'import "{dep}"')

def vector_init(self, values):
return f"[]float64{{{', '.join(values)}}}"
10 changes: 5 additions & 5 deletions m2cgen/interpreters/interpreter.py
@@ -124,20 +124,20 @@ def interpret_vector_val(self, expr, **kwargs):
return self._cg.vector_init(nested)

def interpret_abs_expr(self, expr, **kwargs):
- self.with_math_module = True
if self.abs_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.abs(expr.expr), **kwargs)
+ self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.abs_function_name, nested_result)

def interpret_exp_expr(self, expr, **kwargs):
- self.with_math_module = True
if self.exponent_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.exp(expr.expr, to_reuse=expr.to_reuse),
**kwargs)
+ self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.exponent_function_name, nested_result)
@@ -151,29 +151,29 @@ def interpret_log_expr(self, expr, **kwargs):
self.logarithm_function_name, nested_result)

def interpret_log1p_expr(self, expr, **kwargs):
- self.with_math_module = True
if self.log1p_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.log1p(expr.expr), **kwargs)
+ self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.log1p_function_name, nested_result)

def interpret_sqrt_expr(self, expr, **kwargs):
- self.with_math_module = True
if self.sqrt_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.sqrt(expr.expr, to_reuse=expr.to_reuse),
**kwargs)
+ self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.sqrt_function_name, nested_result)

def interpret_tanh_expr(self, expr, **kwargs):
- self.with_math_module = True
if self.tanh_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.tanh(expr.expr), **kwargs)
+ self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.tanh_function_name, nested_result)
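The interpreter change above moves `self.with_math_module = True` below the `NotImplemented` check in each math helper, so the flag is only raised when a native math function is actually emitted; when the interpreter falls back to plain arithmetic, no math-module dependency is recorded (the matching expectation update appears in tests/test_fallback_expressions.py below). A self-contained sketch of that pattern, with a toy class in place of the real interpreter:

```python
# Toy example; TinyInterpreter is not part of m2cgen.
class TinyInterpreter:
    # Assume the target language has no native abs(), as in the fallback tests.
    abs_function_name = NotImplemented

    def __init__(self):
        self.with_math_module = False

    def interpret_abs(self, operand):
        if self.abs_function_name is NotImplemented:
            # Fallback: plain arithmetic, so no math header/import is needed.
            return f"(({operand}) < 0) ? -({operand}) : ({operand})"
        # Only the native call requires the math module.
        self.with_math_module = True
        return f"{self.abs_function_name}({operand})"


interp = TinyInterpreter()
print(interp.interpret_abs("x"))   # fallback expression
print(interp.with_math_module)     # False -> no spurious "#include <math.h>"
```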
4 changes: 2 additions & 2 deletions requirements-test.txt
@@ -9,10 +9,10 @@ git+git://github.com/scikit-learn-contrib/lightning.git@782c18c12961e509099ae84c
flake8==3.8.3
pytest==5.4.3
pytest-mock==3.1.1
- coveralls==2.0.0
+ coveralls==2.1.0
pytest-cov==2.10.0

# Other stuff
numpy==1.18.5
- scipy==1.5.0
+ scipy==1.5.1
py-mini-racer==0.3.0
30 changes: 15 additions & 15 deletions tests/assemblers/test_linear.py
@@ -135,7 +135,7 @@ def test_statsmodels_wo_const():
estimator = utils.StatsmodelsSklearnLikeWrapper(sm.GLS, {})
_, __, estimator = utils.get_regression_model_trainer()(estimator)

- assembler = assemblers.StatsmodelsLinearModelAssembler(estimator)
+ assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator)
actual = assembler.assemble()

feature_weight_mul = [
@@ -207,7 +207,7 @@ def test_statsmodels_w_const():
dict(init=dict(fit_intercept=True)))
_, __, estimator = utils.get_regression_model_trainer()(estimator)

- assembler = assemblers.StatsmodelsLinearModelAssembler(estimator)
+ assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator)
actual = assembler.assemble()

feature_weight_mul = [
@@ -280,7 +280,7 @@ def test_statsmodels_unknown_constant_position():
dict(init=dict(hasconst=True)))
_, __, estimator = utils.get_regression_model_trainer()(estimator)

- assembler = assemblers.StatsmodelsLinearModelAssembler(estimator)
+ assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator)
assembler.assemble()


@@ -377,7 +377,7 @@ def test_statsmodels_glm_logit_link_func():
fit=dict(maxiter=1)))
estimator = estimator.fit([[1], [2]], [0.1, 0.2])

- assembler = assemblers.StatsmodelsGLMModelAssembler(estimator)
+ assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator)
actual = assembler.assemble()

expected = ast.BinNumExpr(
@@ -410,7 +410,7 @@ def test_statsmodels_glm_power_link_func():
fit=dict(maxiter=1)))
estimator = estimator.fit([[1], [2]], [0.1, 0.2])

- assembler = assemblers.StatsmodelsGLMModelAssembler(estimator)
+ assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator)
actual = assembler.assemble()

expected = ast.PowExpr(
@@ -435,7 +435,7 @@ def test_statsmodels_glm_negative_power_link_func():
fit=dict(maxiter=1)))
estimator = estimator.fit([[1], [2]], [0.1, 0.2])

- assembler = assemblers.StatsmodelsGLMModelAssembler(estimator)
+ assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator)
actual = assembler.assemble()

expected = ast.BinNumExpr(
@@ -463,7 +463,7 @@ def test_statsmodels_glm_inverse_power_link_func():
fit=dict(maxiter=1)))
estimator = estimator.fit([[1], [2]], [0.1, 0.2])

- assembler = assemblers.StatsmodelsGLMModelAssembler(estimator)
+ assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator)
actual = assembler.assemble()

expected = ast.BinNumExpr(
@@ -489,7 +489,7 @@ def test_statsmodels_glm_inverse_squared_link_func():
fit=dict(maxiter=1)))
estimator = estimator.fit([[1], [2]], [0.1, 0.2])

- assembler = assemblers.StatsmodelsGLMModelAssembler(estimator)
+ assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator)
actual = assembler.assemble()

expected = ast.BinNumExpr(
@@ -516,7 +516,7 @@ def test_statsmodels_glm_sqr_power_link_func():
fit=dict(maxiter=1)))
estimator = estimator.fit([[1], [2]], [0.1, 0.2])

- assembler = assemblers.StatsmodelsGLMModelAssembler(estimator)
+ assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator)
actual = assembler.assemble()

expected = ast.SqrtExpr(
@@ -540,7 +540,7 @@ def test_statsmodels_glm_identity_link_func():
fit=dict(maxiter=1)))
estimator = estimator.fit([[1], [2], [3]], [0.1, 0.2, 0.2])

- assembler = assemblers.StatsmodelsGLMModelAssembler(estimator)
+ assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator)
actual = assembler.assemble()

expected = ast.BinNumExpr(
@@ -563,7 +563,7 @@ def test_statsmodels_glm_sqrt_link_func():
fit=dict(maxiter=1)))
estimator = estimator.fit([[1], [2]], [0.1, 0.2])

- assembler = assemblers.StatsmodelsGLMModelAssembler(estimator)
+ assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator)
actual = assembler.assemble()

expected = ast.PowExpr(
@@ -588,7 +588,7 @@ def test_statsmodels_glm_log_link_func():
fit=dict(maxiter=1)))
estimator = estimator.fit([[1], [2]], [0.1, 0.2])

- assembler = assemblers.StatsmodelsGLMModelAssembler(estimator)
+ assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator)
actual = assembler.assemble()

expected = ast.ExpExpr(
@@ -612,7 +612,7 @@ def test_statsmodels_glm_cloglog_link_func():
fit=dict(maxiter=1)))
estimator = estimator.fit([[1], [2]], [0.1, 0.2])

- assembler = assemblers.StatsmodelsGLMModelAssembler(estimator)
+ assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator)
actual = assembler.assemble()

expected = ast.BinNumExpr(
@@ -643,7 +643,7 @@ def test_statsmodels_glm_negativebinomial_link_func():
fit=dict(maxiter=1)))
estimator = estimator.fit([[1], [2]], [0.1, 0.2])

- assembler = assemblers.StatsmodelsGLMModelAssembler(estimator)
+ assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator)
actual = assembler.assemble()

expected = ast.BinNumExpr(
@@ -683,7 +683,7 @@ class ValidPowerLink(sm.families.links.Power):
fit=dict(maxiter=1)))
estimator = estimator.fit([[1], [2]], [0.1, 0.2])

- assembler = assemblers.StatsmodelsGLMModelAssembler(estimator)
+ assembler = assemblers.StatsmodelsModelAssemblerSelector(estimator)
assembler.assemble()


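Every statsmodels test above now builds its assembler through `assemblers.StatsmodelsModelAssemblerSelector`, which is expected to pick the concrete linear or GLM assembler from the fitted estimator instead of each test hard-coding one. A rough sketch of that selector idea; the class below and its dispatch rule are invented for illustration and are not m2cgen's implementation:

```python
# Hypothetical selector; the name and dispatch logic are assumptions.
class AssemblerSelectorSketch:
    def __init__(self, estimator, glm_assembler_cls, linear_assembler_cls):
        # Choose an assembler based on the wrapped statsmodels model type.
        model_name = type(estimator.model).__name__
        if model_name == "GLM":
            self._assembler = glm_assembler_cls(estimator)
        else:
            self._assembler = linear_assembler_cls(estimator)

    def assemble(self):
        # Delegate to whichever concrete assembler was selected.
        return self._assembler.assemble()
```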
18 changes: 18 additions & 0 deletions tests/test_ast.py
@@ -92,6 +92,24 @@ def test_exprs_hash():
assert hash(EXPR_WITH_ALL_EXPRS) == hash(expr_copy)


+ def test_exprs_str():
+ assert str(EXPR_WITH_ALL_EXPRS) == """
+ BinVectorNumExpr(BinVectorExpr(VectorVal([
+ AbsExpr(NumVal(-2),to_reuse=False),
+ ExpExpr(NumVal(2),to_reuse=False),
+ LogExpr(NumVal(2),to_reuse=False),
+ Log1pExpr(NumVal(2),to_reuse=False),
+ SqrtExpr(NumVal(2),to_reuse=False),
+ PowExpr(NumVal(2),NumVal(3),to_reuse=False),
+ TanhExpr(NumVal(1),to_reuse=False),
+ BinNumExpr(NumVal(0),FeatureRef(0),to_reuse=False)]),
+ IdExpr(VectorVal([
+ NumVal(1),NumVal(2),NumVal(3),NumVal(4),NumVal(5),NumVal(6),NumVal(7),
+ FeatureRef(1)]),to_reuse=False),SUB),
+ IfExpr(CompExpr(NumVal(2),NumVal(0),GT),NumVal(3),NumVal(4)),MUL)
+ """.strip().replace("\n", "")


def test_num_val():
assert type(ast.NumVal(1).value) == int
assert type(ast.NumVal(1, dtype=np.float32).value) == np.float32
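The new `test_exprs_str` pins down the compact one-line `str()` form of AST expressions. Assuming the same `from m2cgen import ast` import the test module uses, a smaller example in that format (the exact output is inferred from the test string above and may differ in other versions):

```python
from m2cgen import ast  # assumption: the import path used by these tests

expr = ast.AbsExpr(ast.NumVal(-2))
# Format inferred from the test_exprs_str assertion above:
print(str(expr))  # AbsExpr(NumVal(-2),to_reuse=False)
```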
1 change: 0 additions & 1 deletion tests/test_fallback_expressions.py
@@ -11,7 +11,6 @@ def test_abs_fallback_expr():
interpreter.abs_function_name = NotImplemented

expected_code = """
- #include <math.h>
double score(double * input) {
double var0;
double var1;
