Skip to content

Commit

Permalink
Merge branch 'master' into sklearn_glm
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Jun 19, 2020
2 parents 97a075f + 813d2e9 commit aa33a19
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 63 deletions.
32 changes: 16 additions & 16 deletions m2cgen/assemblers/utils.py
Expand Up @@ -3,36 +3,36 @@
from m2cgen import ast


def mul(l, r, to_reuse=False):
return ast.BinNumExpr(l, r, ast.BinNumOpType.MUL, to_reuse=to_reuse)
def mul(left, right, to_reuse=False):
return ast.BinNumExpr(left, right, ast.BinNumOpType.MUL, to_reuse=to_reuse)


def div(l, r, to_reuse=False):
return ast.BinNumExpr(l, r, ast.BinNumOpType.DIV, to_reuse=to_reuse)
def div(left, right, to_reuse=False):
return ast.BinNumExpr(left, right, ast.BinNumOpType.DIV, to_reuse=to_reuse)


def add(l, r, to_reuse=False):
return ast.BinNumExpr(l, r, ast.BinNumOpType.ADD, to_reuse=to_reuse)
def add(left, right, to_reuse=False):
return ast.BinNumExpr(left, right, ast.BinNumOpType.ADD, to_reuse=to_reuse)


def sub(l, r, to_reuse=False):
return ast.BinNumExpr(l, r, ast.BinNumOpType.SUB, to_reuse=to_reuse)
def sub(left, right, to_reuse=False):
return ast.BinNumExpr(left, right, ast.BinNumOpType.SUB, to_reuse=to_reuse)


def lt(l, r):
return ast.CompExpr(l, r, ast.CompOpType.LT)
def lt(left, right):
return ast.CompExpr(left, right, ast.CompOpType.LT)


def lte(l, r):
return ast.CompExpr(l, r, ast.CompOpType.LTE)
def lte(left, right):
return ast.CompExpr(left, right, ast.CompOpType.LTE)


def gt(l, r):
return ast.CompExpr(l, r, ast.CompOpType.GT)
def gt(left, right):
return ast.CompExpr(left, right, ast.CompOpType.GT)


def eq(l, r):
return ast.CompExpr(l, r, ast.CompOpType.EQ)
def eq(left, right):
return ast.CompExpr(left, right, ast.CompOpType.EQ)


BIN_EXPR_CLASSES = {
Expand Down
99 changes: 99 additions & 0 deletions m2cgen/ast.py
Expand Up @@ -23,6 +23,12 @@ def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "IdExpr(" + args + ")"

def __eq__(self, other):
return type(other) is IdExpr and self.expr == other.expr

def __hash__(self):
return hash(self.expr)


class FeatureRef(Expr):
def __init__(self, index):
Expand All @@ -31,6 +37,12 @@ def __init__(self, index):
def __str__(self):
return "FeatureRef(" + str(self.index) + ")"

def __eq__(self, other):
return type(other) is FeatureRef and self.index == other.index

def __hash__(self):
return hash(self.index)


class BinExpr(Expr):
pass
Expand All @@ -51,6 +63,12 @@ def __init__(self, value, dtype=None):
def __str__(self):
return "NumVal(" + str(self.value) + ")"

def __eq__(self, other):
return type(other) is NumVal and self.value == other.value

def __hash__(self):
return hash(self.value)


class AbsExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
Expand All @@ -63,6 +81,12 @@ def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "AbsExpr(" + args + ")"

def __eq__(self, other):
return type(other) is AbsExpr and self.expr == other.expr

def __hash__(self):
return hash(self.expr)


class ExpExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
Expand All @@ -75,6 +99,12 @@ def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "ExpExpr(" + args + ")"

def __eq__(self, other):
return type(other) is ExpExpr and self.expr == other.expr

def __hash__(self):
return hash(self.expr)


class SqrtExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
Expand All @@ -87,6 +117,12 @@ def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "SqrtExpr(" + args + ")"

def __eq__(self, other):
return type(other) is SqrtExpr and self.expr == other.expr

def __hash__(self):
return hash(self.expr)


class TanhExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
Expand All @@ -99,6 +135,12 @@ def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "TanhExpr(" + args + ")"

def __eq__(self, other):
return type(other) is TanhExpr and self.expr == other.expr

def __hash__(self):
return hash(self.expr)


class PowExpr(NumExpr):
def __init__(self, base_expr, exp_expr, to_reuse=False):
Expand All @@ -117,6 +159,14 @@ def __str__(self):
])
return "PowExpr(" + args + ")"

def __eq__(self, other):
return (type(other) is PowExpr and
self.base_expr == other.base_expr and
self.exp_expr == other.exp_expr)

def __hash__(self):
return hash((self.base_expr, self.exp_expr))


class BinNumOpType(Enum):
ADD = '+'
Expand Down Expand Up @@ -144,6 +194,12 @@ def __str__(self):
])
return "BinNumExpr(" + args + ")"

def __eq__(self, other):
return _eq_bin_exprs(self, other, type(self))

def __hash__(self):
return hash((self.left, self.right, self.op))


# Vector Expressions.

Expand All @@ -164,6 +220,14 @@ def __str__(self):
args = ",".join([str(e) for e in self.exprs])
return "VectorVal([" + args + "])"

def __eq__(self, other):
return (type(other) is VectorVal and
self.output_size == other.output_size and
all(i == j for i, j in zip(self.exprs, other.exprs)))

def __hash__(self):
return hash(tuple(self.exprs))


class BinVectorExpr(VectorExpr, BinExpr):

Expand All @@ -181,6 +245,12 @@ def __str__(self):
args = ",".join([str(self.left), str(self.right), self.op.name])
return "BinVectorExpr(" + args + ")"

def __eq__(self, other):
return _eq_bin_exprs(self, other, type(self))

def __hash__(self):
return hash((self.left, self.right, self.op))


class BinVectorNumExpr(VectorExpr, BinExpr):

Expand All @@ -197,6 +267,12 @@ def __str__(self):
args = ",".join([str(self.left), str(self.right), self.op.name])
return "BinVectorNumExpr(" + args + ")"

def __eq__(self, other):
return _eq_bin_exprs(self, other, type(self))

def __hash__(self):
return hash((self.left, self.right, self.op))


# Boolean Expressions.

Expand Down Expand Up @@ -233,6 +309,12 @@ def __str__(self):
args = ",".join([str(self.left), str(self.right), self.op.name])
return "CompExpr(" + args + ")"

def __eq__(self, other):
return _eq_bin_exprs(self, other, type(self))

def __hash__(self):
return hash((self.left, self.right, self.op))


# Control Expressions.

Expand All @@ -254,6 +336,15 @@ def __str__(self):
args = ",".join([str(self.test), str(self.body), str(self.orelse)])
return "IfExpr(" + args + ")"

def __eq__(self, other):
return (type(other) is IfExpr and
self.test == other.test and
self.body == other.body and
self.orelse == other.orelse)

def __hash__(self):
return hash((self.test, self.body, self.orelse))


TOTAL_NUMBER_OF_EXPRESSIONS = len(getmembers(modules[__name__], isclass))

Expand Down Expand Up @@ -286,3 +377,11 @@ def count_exprs(expr, exclude_list=None):

expr_type_name = expr_type.__name__
raise ValueError("Unexpected expression type '{}'".format(expr_type_name))


def _eq_bin_exprs(expr_one, expr_two, expected_type):
return (type(expr_one) is expected_type and
type(expr_two) is expected_type and
expr_one.left == expr_two.left and
expr_one.right == expr_two.right and
expr_one.op == expr_two.op)
6 changes: 3 additions & 3 deletions m2cgen/interpreters/interpreter.py
Expand Up @@ -29,6 +29,9 @@ def _do_interpret(self, expr, to_reuse=None, **kwargs):
if result is not None:
return result

if expr in self._cached_expr_results:
return self._cached_expr_results[expr].var_name

handler = self._select_handler(expr)

# Note that the reuse flag passed in the arguments has a higher
Expand All @@ -40,9 +43,6 @@ def _do_interpret(self, expr, to_reuse=None, **kwargs):
if not expr_to_reuse:
return handler(expr, **kwargs)

if expr in self._cached_expr_results:
return self._cached_expr_results[expr].var_name

result = handler(expr, **kwargs)
return self._cache_reused_expr(expr, result)

Expand Down
4 changes: 3 additions & 1 deletion m2cgen/interpreters/mixins.py
Expand Up @@ -36,8 +36,10 @@ def _pre_interpret_hook(self, expr, bin_depth=0, **kwargs):

# Default implementation. Simply adds new variable.
def bin_depth_threshold_hook(self, expr, **kwargs):
var_name = self._cg.add_var_declaration(expr.output_size)
if expr in self._cached_expr_results:
return self._cached_expr_results[expr].var_name
result = self._do_interpret(expr, **kwargs)
var_name = self._cg.add_var_declaration(expr.output_size)
self._cg.add_var_assignment(var_name, result, expr.output_size)
return var_name

Expand Down
4 changes: 2 additions & 2 deletions requirements-test.txt
@@ -1,12 +1,12 @@
# Supported models
scikit-learn==0.23.1
xgboost==1.0.2
xgboost==1.1.1
lightgbm==2.3.1
statsmodels==0.11.1
git+git://github.com/scikit-learn-contrib/lightning.git@782c18c12961e509099ae84c68dd361010642f7e

# Testing tools
flake8==3.7.9
flake8==3.8.3
pytest==5.4.3
pytest-mock==3.1.1
coveralls==2.0.0
Expand Down
20 changes: 10 additions & 10 deletions tests/interpreters/test_dart.py
Expand Up @@ -333,12 +333,12 @@ def test_deep_mixed_exprs_exceeding_threshold():
expr = ast.NumVal(1)
for i in range(4):
inner = ast.NumVal(1)
for i in range(4):
inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD)
for j in range(4):
inner = ast.BinNumExpr(ast.NumVal(i), inner, ast.BinNumOpType.ADD)

expr = ast.IfExpr(
ast.CompExpr(
inner, ast.NumVal(1), ast.CompOpType.EQ),
inner, ast.NumVal(j), ast.CompOpType.EQ),
ast.NumVal(1),
expr)

Expand All @@ -348,23 +348,23 @@ def test_deep_mixed_exprs_exceeding_threshold():
double score(List<double> input) {
double var0;
double var1;
var1 = (1) + ((1) + (1));
if (((1) + ((1) + (var1))) == (1)) {
var1 = (3) + ((3) + (1));
if (((3) + ((3) + (var1))) == (3)) {
var0 = 1;
} else {
double var2;
var2 = (1) + ((1) + (1));
if (((1) + ((1) + (var2))) == (1)) {
var2 = (2) + ((2) + (1));
if (((2) + ((2) + (var2))) == (3)) {
var0 = 1;
} else {
double var3;
var3 = (1) + ((1) + (1));
if (((1) + ((1) + (var3))) == (1)) {
if (((1) + ((1) + (var3))) == (3)) {
var0 = 1;
} else {
double var4;
var4 = (1) + ((1) + (1));
if (((1) + ((1) + (var4))) == (1)) {
var4 = (0) + ((0) + (1));
if (((0) + ((0) + (var4))) == (3)) {
var0 = 1;
} else {
var0 = 1;
Expand Down
20 changes: 10 additions & 10 deletions tests/interpreters/test_python.py
Expand Up @@ -281,33 +281,33 @@ def test_deep_mixed_exprs_exceeding_threshold():
expr = ast.NumVal(1)
for i in range(4):
inner = ast.NumVal(1)
for i in range(4):
inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD)
for j in range(4):
inner = ast.BinNumExpr(ast.NumVal(i), inner, ast.BinNumOpType.ADD)

expr = ast.IfExpr(
ast.CompExpr(
inner, ast.NumVal(1), ast.CompOpType.EQ),
inner, ast.NumVal(j), ast.CompOpType.EQ),
ast.NumVal(1),
expr)

interpreter = CustomPythonInterpreter()

expected_code = """
def score(input):
var1 = (1) + ((1) + (1))
if ((1) + ((1) + (var1))) == (1):
var1 = (3) + ((3) + (1))
if ((3) + ((3) + (var1))) == (3):
var0 = 1
else:
var2 = (1) + ((1) + (1))
if ((1) + ((1) + (var2))) == (1):
var2 = (2) + ((2) + (1))
if ((2) + ((2) + (var2))) == (3):
var0 = 1
else:
var3 = (1) + ((1) + (1))
if ((1) + ((1) + (var3))) == (1):
if ((1) + ((1) + (var3))) == (3):
var0 = 1
else:
var4 = (1) + ((1) + (1))
if ((1) + ((1) + (var4))) == (1):
var4 = (0) + ((0) + (1))
if ((0) + ((0) + (var4))) == (3):
var0 = 1
else:
var0 = 1
Expand Down

0 comments on commit aa33a19

Please sign in to comment.