Skip to content

Commit

Permalink
fixed conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Jun 23, 2020
2 parents 1fa17aa + c4b8d29 commit e5e748f
Show file tree
Hide file tree
Showing 39 changed files with 1,197 additions and 8 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -1,5 +1,6 @@
include LICENSE
recursive-include m2cgen VERSION.txt
recursive-include m2cgen linear_algebra.*
recursive-include m2cgen log1p.*
recursive-include m2cgen tanh.*
global-exclude *.py[cod]
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -38,7 +38,7 @@ pip install m2cgen
- Python
- R
- Ruby
- Visual Basic
- Visual Basic (VBA-compatible)

## Supported Models

Expand Down
12 changes: 12 additions & 0 deletions m2cgen/assemblers/fallback_expressions.py
Expand Up @@ -54,6 +54,18 @@ def exp(expr, to_reuse=False):
to_reuse=to_reuse)


def log1p(expr):
# Use trick to compute log1p for small values more accurate
# https://www.johndcook.com/blog/2012/07/25/trick-for-computing-log1x/
expr = ast.IdExpr(expr, to_reuse=True)
expr1p = utils.add(ast.NumVal(1.0), expr, to_reuse=True)
expr1pm1 = utils.sub(expr1p, ast.NumVal(1.0), to_reuse=True)
return ast.IfExpr(
utils.eq(expr1pm1, ast.NumVal(0.0)),
expr,
utils.div(utils.mul(expr, ast.LogExpr(expr1p)), expr1pm1))


def sigmoid(expr, to_reuse=False):
neg_expr = ast.BinNumExpr(ast.NumVal(0), expr, ast.BinNumOpType.SUB)
exp_expr = ast.ExpExpr(neg_expr)
Expand Down
39 changes: 38 additions & 1 deletion m2cgen/ast.py
Expand Up @@ -106,6 +106,42 @@ def __hash__(self):
return hash(self.expr)


class LogExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
assert expr.output_size == 1, "Only scalars are supported"

self.expr = expr
self.to_reuse = to_reuse

def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "LogExpr(" + args + ")"

def __eq__(self, other):
return type(other) is LogExpr and self.expr == other.expr

def __hash__(self):
return hash(self.expr)


class Log1pExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
assert expr.output_size == 1, "Only scalars are supported"

self.expr = expr
self.to_reuse = to_reuse

def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "Log1pExpr(" + args + ")"

def __eq__(self, other):
return type(other) is Log1pExpr and self.expr == other.expr

def __hash__(self):
return hash(self.expr)


class SqrtExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
assert expr.output_size == 1, "Only scalars are supported"
Expand Down Expand Up @@ -354,7 +390,8 @@ def __hash__(self):
(PowExpr, lambda e: [e.base_expr, e.exp_expr]),
(VectorVal, lambda e: e.exprs),
(IfExpr, lambda e: [e.test, e.body, e.orelse]),
((AbsExpr, IdExpr, ExpExpr, SqrtExpr, TanhExpr), lambda e: [e.expr]),
((AbsExpr, ExpExpr, IdExpr, LogExpr, Log1pExpr, SqrtExpr, TanhExpr),
lambda e: [e.expr]),
]


Expand Down
2 changes: 2 additions & 0 deletions m2cgen/interpreters/c/interpreter.py
Expand Up @@ -19,6 +19,8 @@ class CInterpreter(ImperativeToCodeInterpreter,

abs_function_name = "fabs"
exponent_function_name = "exp"
logarithm_function_name = "log"
log1p_function_name = "log1p"
power_function_name = "pow"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"
Expand Down
13 changes: 13 additions & 0 deletions m2cgen/interpreters/c_sharp/interpreter.py
Expand Up @@ -20,10 +20,14 @@ class CSharpInterpreter(ImperativeToCodeInterpreter,

abs_function_name = "Abs"
exponent_function_name = "Exp"
logarithm_function_name = "Log"
log1p_function_name = "Log1p"
power_function_name = "Pow"
sqrt_function_name = "Sqrt"
tanh_function_name = "Tanh"

with_log1p_expr = False

def __init__(self, namespace="ML", class_name="Model", indent=4,
function_name="Score", *args, **kwargs):
self.namespace = namespace
Expand Down Expand Up @@ -56,7 +60,16 @@ def interpret(self, expr):
os.path.dirname(__file__), "linear_algebra.cs")
self._cg.add_code_lines(utils.get_file_content(filename))

if self.with_log1p_expr:
filename = os.path.join(
os.path.dirname(__file__), "log1p.cs")
self._cg.add_code_lines(utils.get_file_content(filename))

if self.with_math_module:
self._cg.add_dependency("System.Math")

return self._cg.finalize_and_get_generated_code()

def interpret_log1p_expr(self, expr, **kwargs):
self.with_log1p_expr = True
return super().interpret_log1p_expr(expr, **kwargs)
51 changes: 51 additions & 0 deletions m2cgen/interpreters/c_sharp/log1p.cs
@@ -0,0 +1,51 @@
private static double Log1p(double x) {
if (x == 0.0)
return 0.0;
if (x == -1.0)
return double.NegativeInfinity;
if (x < -1.0)
return double.NaN;
double xAbs = Abs(x);
if (xAbs < 0.5 * double.Epsilon)
return x;
if ((x > 0.0 && x < 1e-8) || (x > -1e-9 && x < 0.0))
return x * (1.0 - x * 0.5);
if (xAbs < 0.375) {
double[] coeffs = {
0.10378693562743769800686267719098e+1,
-0.13364301504908918098766041553133e+0,
0.19408249135520563357926199374750e-1,
-0.30107551127535777690376537776592e-2,
0.48694614797154850090456366509137e-3,
-0.81054881893175356066809943008622e-4,
0.13778847799559524782938251496059e-4,
-0.23802210894358970251369992914935e-5,
0.41640416213865183476391859901989e-6,
-0.73595828378075994984266837031998e-7,
0.13117611876241674949152294345011e-7,
-0.23546709317742425136696092330175e-8,
0.42522773276034997775638052962567e-9,
-0.77190894134840796826108107493300e-10,
0.14075746481359069909215356472191e-10,
-0.25769072058024680627537078627584e-11,
0.47342406666294421849154395005938e-12,
-0.87249012674742641745301263292675e-13,
0.16124614902740551465739833119115e-13,
-0.29875652015665773006710792416815e-14,
0.55480701209082887983041321697279e-15,
-0.10324619158271569595141333961932e-15};
return x * (1.0 - x * ChebyshevBroucke(x / 0.375, coeffs));
}
return Log(1.0 + x);
}
private static double ChebyshevBroucke(double x, double[] coeffs) {
double b0, b1, b2, x2;
b2 = b1 = b0 = 0.0;
x2 = x * 2;
for (int i = coeffs.Length - 1; i >= 0; --i) {
b2 = b1;
b1 = b0;
b0 = x2 * b1 - b2 + coeffs[i];
}
return (b0 - b2) * 0.5;
}
16 changes: 13 additions & 3 deletions m2cgen/interpreters/dart/interpreter.py
Expand Up @@ -23,10 +23,13 @@ class DartInterpreter(ImperativeToCodeInterpreter,

abs_function_name = "abs"
exponent_function_name = "exp"
logarithm_function_name = "log"
log1p_function_name = "log1p"
power_function_name = "pow"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"

with_log1p_expr = False
with_tanh_expr = False

def __init__(self, indent=4, function_name="score", *args, **kwargs):
Expand Down Expand Up @@ -54,7 +57,11 @@ def interpret(self, expr):
os.path.dirname(__file__), "linear_algebra.dart")
self._cg.add_code_lines(utils.get_file_content(filename))

# Use own tanh function in order to be compatible with Dart
if self.with_log1p_expr:
filename = os.path.join(
os.path.dirname(__file__), "log1p.dart")
self._cg.add_code_lines(utils.get_file_content(filename))

if self.with_tanh_expr:
filename = os.path.join(
os.path.dirname(__file__), "tanh.dart")
Expand All @@ -71,7 +78,10 @@ def interpret_abs_expr(self, expr, **kwargs):
obj=self._do_interpret(expr.expr, **kwargs),
args=[])

def interpret_log1p_expr(self, expr, **kwargs):
self.with_log1p_expr = True
return super().interpret_log1p_expr(expr, **kwargs)

def interpret_tanh_expr(self, expr, **kwargs):
self.with_tanh_expr = True
return super(
DartInterpreter, self).interpret_tanh_expr(expr, **kwargs)
return super().interpret_tanh_expr(expr, **kwargs)
51 changes: 51 additions & 0 deletions m2cgen/interpreters/dart/log1p.dart
@@ -0,0 +1,51 @@
double log1p(double x) {
if (x == 0.0)
return 0.0;
if (x == -1.0)
return double.negativeInfinity;
if (x < -1.0)
return double.nan;
double xAbs = x.abs();
if (xAbs < 0.5 * 4.94065645841247e-324)
return x;
if ((x > 0.0 && x < 1e-8) || (x > -1e-9 && x < 0.0))
return x * (1.0 - x * 0.5);
if (xAbs < 0.375) {
List<double> coeffs = [
0.10378693562743769800686267719098e+1,
-0.13364301504908918098766041553133e+0,
0.19408249135520563357926199374750e-1,
-0.30107551127535777690376537776592e-2,
0.48694614797154850090456366509137e-3,
-0.81054881893175356066809943008622e-4,
0.13778847799559524782938251496059e-4,
-0.23802210894358970251369992914935e-5,
0.41640416213865183476391859901989e-6,
-0.73595828378075994984266837031998e-7,
0.13117611876241674949152294345011e-7,
-0.23546709317742425136696092330175e-8,
0.42522773276034997775638052962567e-9,
-0.77190894134840796826108107493300e-10,
0.14075746481359069909215356472191e-10,
-0.25769072058024680627537078627584e-11,
0.47342406666294421849154395005938e-12,
-0.87249012674742641745301263292675e-13,
0.16124614902740551465739833119115e-13,
-0.29875652015665773006710792416815e-14,
0.55480701209082887983041321697279e-15,
-0.10324619158271569595141333961932e-15];
return x * (1.0 - x * chebyshevBroucke(x / 0.375, coeffs));
}
return log(1.0 + x);
}
double chebyshevBroucke(double x, List<double> coeffs) {
double b0, b1, b2, x2;
b2 = b1 = b0 = 0.0;
x2 = x * 2;
for (int i = coeffs.length - 1; i >= 0; --i) {
b2 = b1;
b1 = b0;
b0 = x2 * b1 - b2 + coeffs[i];
}
return (b0 - b2) * 0.5;
}
2 changes: 2 additions & 0 deletions m2cgen/interpreters/go/interpreter.py
Expand Up @@ -18,6 +18,8 @@ class GoInterpreter(ImperativeToCodeInterpreter,

abs_function_name = "math.Abs"
exponent_function_name = "math.Exp"
logarithm_function_name = "math.Log"
log1p_function_name = "math.Log1p"
power_function_name = "math.Pow"
sqrt_function_name = "math.Sqrt"
tanh_function_name = "math.Tanh"
Expand Down
13 changes: 13 additions & 0 deletions m2cgen/interpreters/haskell/interpreter.py
Expand Up @@ -18,9 +18,13 @@ class HaskellInterpreter(ToCodeInterpreter,

abs_function_name = "abs"
exponent_function_name = "exp"
logarithm_function_name = "log"
log1p_function_name = "log1p"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"

with_log1p_expr = False

def __init__(self, module_name="Model", indent=4, function_name="score",
*args, **kwargs):
self.module_name = module_name
Expand Down Expand Up @@ -50,6 +54,11 @@ def interpret(self, expr):
os.path.dirname(__file__), "linear_algebra.hs")
self._cg.prepend_code_lines(utils.get_file_content(filename))

if self.with_log1p_expr:
filename = os.path.join(
os.path.dirname(__file__), "log1p.hs")
self._cg.prepend_code_lines(utils.get_file_content(filename))

self._cg.prepend_code_line(self._cg.tpl_module_definition(
module_name=self.module_name))

Expand Down Expand Up @@ -82,6 +91,10 @@ def interpret_pow_expr(self, expr, **kwargs):
return self._cg.infix_expression(
left=base_result, right=exp_result, op="**")

def interpret_log1p_expr(self, expr, **kwargs):
self.with_log1p_expr = True
return super().interpret_log1p_expr(expr, **kwargs)

# Cached expressions become functions with no arguments, i.e. values
# which are CAFs. Therefore, they are computed only once.
def _cache_reused_expr(self, expr, expr_result):
Expand Down
40 changes: 40 additions & 0 deletions m2cgen/interpreters/haskell/log1p.hs
@@ -0,0 +1,40 @@
log1p :: Double -> Double
log1p x
| x == 0 = 0
| x == -1 = -1 / 0
| x < -1 = 0 / 0
| x' < m_epsilon * 0.5 = x
| (x > 0 && x < 1e-8) || (x > -1e-9 && x < 0)
= x * (1 - x * 0.5)
| x' < 0.375 = x * (1 - x * chebyshevBroucke (x / 0.375) coeffs)
| otherwise = log (1 + x)
where
m_epsilon = encodeFloat (signif + 1) expo - 1.0
where (signif, expo) = decodeFloat (1.0::Double)
x' = abs x
coeffs = [0.10378693562743769800686267719098e+1,
-0.13364301504908918098766041553133e+0,
0.19408249135520563357926199374750e-1,
-0.30107551127535777690376537776592e-2,
0.48694614797154850090456366509137e-3,
-0.81054881893175356066809943008622e-4,
0.13778847799559524782938251496059e-4,
-0.23802210894358970251369992914935e-5,
0.41640416213865183476391859901989e-6,
-0.73595828378075994984266837031998e-7,
0.13117611876241674949152294345011e-7,
-0.23546709317742425136696092330175e-8,
0.42522773276034997775638052962567e-9,
-0.77190894134840796826108107493300e-10,
0.14075746481359069909215356472191e-10,
-0.25769072058024680627537078627584e-11,
0.47342406666294421849154395005938e-12,
-0.87249012674742641745301263292675e-13,
0.16124614902740551465739833119115e-13,
-0.29875652015665773006710792416815e-14,
0.55480701209082887983041321697279e-15,
-0.10324619158271569595141333961932e-15]
chebyshevBroucke i = fini . foldr step [0, 0, 0]
where
step k [b0, b1, _] = [(k + i * 2 * b0 - b1), b0, b1]
fini [b0, _, b2] = (b0 - b2) * 0.5
19 changes: 19 additions & 0 deletions m2cgen/interpreters/interpreter.py
Expand Up @@ -83,6 +83,8 @@ class ToCodeInterpreter(BaseToCodeInterpreter):

abs_function_name = NotImplemented
exponent_function_name = NotImplemented
logarithm_function_name = NotImplemented
log1p_function_name = NotImplemented
power_function_name = NotImplemented
sqrt_function_name = NotImplemented
tanh_function_name = NotImplemented
Expand Down Expand Up @@ -140,6 +142,23 @@ def interpret_exp_expr(self, expr, **kwargs):
return self._cg.function_invocation(
self.exponent_function_name, nested_result)

def interpret_log_expr(self, expr, **kwargs):
if self.logarithm_function_name is NotImplemented:
raise NotImplementedError("Logarithm function is not provided")
self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.logarithm_function_name, nested_result)

def interpret_log1p_expr(self, expr, **kwargs):
self.with_math_module = True
if self.log1p_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.log1p(expr.expr), **kwargs)
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.log1p_function_name, nested_result)

def interpret_sqrt_expr(self, expr, **kwargs):
self.with_math_module = True
if self.sqrt_function_name is NotImplemented:
Expand Down

0 comments on commit e5e748f

Please sign in to comment.