Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added log and log1p functions #225

Merged
merged 7 commits into from Jun 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -1,5 +1,6 @@
include LICENSE
recursive-include m2cgen VERSION.txt
recursive-include m2cgen linear_algebra.*
recursive-include m2cgen log1p.*
recursive-include m2cgen tanh.*
global-exclude *.py[cod]
12 changes: 12 additions & 0 deletions m2cgen/assemblers/fallback_expressions.py
Expand Up @@ -54,6 +54,18 @@ def exp(expr, to_reuse=False):
to_reuse=to_reuse)


def log1p(expr):
# Use trick to compute log1p for small values more accurate
# https://www.johndcook.com/blog/2012/07/25/trick-for-computing-log1x/
expr = ast.IdExpr(expr, to_reuse=True)
expr1p = utils.add(ast.NumVal(1.0), expr, to_reuse=True)
expr1pm1 = utils.sub(expr1p, ast.NumVal(1.0), to_reuse=True)
return ast.IfExpr(
utils.eq(expr1pm1, ast.NumVal(0.0)),
expr,
utils.div(utils.mul(expr, ast.LogExpr(expr1p)), expr1pm1))


def sigmoid(expr, to_reuse=False):
neg_expr = ast.BinNumExpr(ast.NumVal(0), expr, ast.BinNumOpType.SUB)
exp_expr = ast.ExpExpr(neg_expr)
Expand Down
39 changes: 38 additions & 1 deletion m2cgen/ast.py
Expand Up @@ -106,6 +106,42 @@ def __hash__(self):
return hash(self.expr)


class LogExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
assert expr.output_size == 1, "Only scalars are supported"

self.expr = expr
self.to_reuse = to_reuse

def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "LogExpr(" + args + ")"

def __eq__(self, other):
return type(other) is LogExpr and self.expr == other.expr

def __hash__(self):
return hash(self.expr)


class Log1pExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
assert expr.output_size == 1, "Only scalars are supported"

self.expr = expr
self.to_reuse = to_reuse

def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "Log1pExpr(" + args + ")"

def __eq__(self, other):
return type(other) is Log1pExpr and self.expr == other.expr

def __hash__(self):
return hash(self.expr)


class SqrtExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
assert expr.output_size == 1, "Only scalars are supported"
Expand Down Expand Up @@ -354,7 +390,8 @@ def __hash__(self):
(PowExpr, lambda e: [e.base_expr, e.exp_expr]),
(VectorVal, lambda e: e.exprs),
(IfExpr, lambda e: [e.test, e.body, e.orelse]),
((AbsExpr, IdExpr, ExpExpr, SqrtExpr, TanhExpr), lambda e: [e.expr]),
((AbsExpr, ExpExpr, IdExpr, LogExpr, Log1pExpr, SqrtExpr, TanhExpr),
lambda e: [e.expr]),
]


Expand Down
2 changes: 2 additions & 0 deletions m2cgen/interpreters/c/interpreter.py
Expand Up @@ -19,6 +19,8 @@ class CInterpreter(ImperativeToCodeInterpreter,

abs_function_name = "fabs"
exponent_function_name = "exp"
logarithm_function_name = "log"
log1p_function_name = "log1p"
power_function_name = "pow"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"
Expand Down
13 changes: 13 additions & 0 deletions m2cgen/interpreters/c_sharp/interpreter.py
Expand Up @@ -20,10 +20,14 @@ class CSharpInterpreter(ImperativeToCodeInterpreter,

abs_function_name = "Abs"
exponent_function_name = "Exp"
logarithm_function_name = "Log"
log1p_function_name = "Log1p"
power_function_name = "Pow"
sqrt_function_name = "Sqrt"
tanh_function_name = "Tanh"

with_log1p_expr = False

def __init__(self, namespace="ML", class_name="Model", indent=4,
function_name="Score", *args, **kwargs):
self.namespace = namespace
Expand Down Expand Up @@ -56,7 +60,16 @@ def interpret(self, expr):
os.path.dirname(__file__), "linear_algebra.cs")
self._cg.add_code_lines(utils.get_file_content(filename))

if self.with_log1p_expr:
filename = os.path.join(
os.path.dirname(__file__), "log1p.cs")
self._cg.add_code_lines(utils.get_file_content(filename))

if self.with_math_module:
self._cg.add_dependency("System.Math")

return self._cg.finalize_and_get_generated_code()

def interpret_log1p_expr(self, expr, **kwargs):
self.with_log1p_expr = True
return super().interpret_log1p_expr(expr, **kwargs)
51 changes: 51 additions & 0 deletions m2cgen/interpreters/c_sharp/log1p.cs
@@ -0,0 +1,51 @@
private static double Log1p(double x) {
if (x == 0.0)
return 0.0;
if (x == -1.0)
return double.NegativeInfinity;
if (x < -1.0)
return double.NaN;
double xAbs = Abs(x);
if (xAbs < 0.5 * double.Epsilon)
return x;
if ((x > 0.0 && x < 1e-8) || (x > -1e-9 && x < 0.0))
return x * (1.0 - x * 0.5);
if (xAbs < 0.375) {
double[] coeffs = {
0.10378693562743769800686267719098e+1,
-0.13364301504908918098766041553133e+0,
0.19408249135520563357926199374750e-1,
-0.30107551127535777690376537776592e-2,
0.48694614797154850090456366509137e-3,
-0.81054881893175356066809943008622e-4,
0.13778847799559524782938251496059e-4,
-0.23802210894358970251369992914935e-5,
0.41640416213865183476391859901989e-6,
-0.73595828378075994984266837031998e-7,
0.13117611876241674949152294345011e-7,
-0.23546709317742425136696092330175e-8,
0.42522773276034997775638052962567e-9,
-0.77190894134840796826108107493300e-10,
0.14075746481359069909215356472191e-10,
-0.25769072058024680627537078627584e-11,
0.47342406666294421849154395005938e-12,
-0.87249012674742641745301263292675e-13,
0.16124614902740551465739833119115e-13,
-0.29875652015665773006710792416815e-14,
0.55480701209082887983041321697279e-15,
-0.10324619158271569595141333961932e-15};
return x * (1.0 - x * ChebyshevBroucke(x / 0.375, coeffs));
}
return Log(1.0 + x);
}
private static double ChebyshevBroucke(double x, double[] coeffs) {
double b0, b1, b2, x2;
b2 = b1 = b0 = 0.0;
x2 = x * 2;
for (int i = coeffs.Length - 1; i >= 0; --i) {
b2 = b1;
b1 = b0;
b0 = x2 * b1 - b2 + coeffs[i];
}
return (b0 - b2) * 0.5;
}
16 changes: 13 additions & 3 deletions m2cgen/interpreters/dart/interpreter.py
Expand Up @@ -23,10 +23,13 @@ class DartInterpreter(ImperativeToCodeInterpreter,

abs_function_name = "abs"
exponent_function_name = "exp"
logarithm_function_name = "log"
log1p_function_name = "log1p"
power_function_name = "pow"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"

with_log1p_expr = False
with_tanh_expr = False

def __init__(self, indent=4, function_name="score", *args, **kwargs):
Expand Down Expand Up @@ -54,7 +57,11 @@ def interpret(self, expr):
os.path.dirname(__file__), "linear_algebra.dart")
self._cg.add_code_lines(utils.get_file_content(filename))

# Use own tanh function in order to be compatible with Dart
if self.with_log1p_expr:
filename = os.path.join(
os.path.dirname(__file__), "log1p.dart")
self._cg.add_code_lines(utils.get_file_content(filename))

if self.with_tanh_expr:
filename = os.path.join(
os.path.dirname(__file__), "tanh.dart")
Expand All @@ -71,7 +78,10 @@ def interpret_abs_expr(self, expr, **kwargs):
obj=self._do_interpret(expr.expr, **kwargs),
args=[])

def interpret_log1p_expr(self, expr, **kwargs):
self.with_log1p_expr = True
return super().interpret_log1p_expr(expr, **kwargs)

def interpret_tanh_expr(self, expr, **kwargs):
self.with_tanh_expr = True
return super(
DartInterpreter, self).interpret_tanh_expr(expr, **kwargs)
return super().interpret_tanh_expr(expr, **kwargs)
51 changes: 51 additions & 0 deletions m2cgen/interpreters/dart/log1p.dart
@@ -0,0 +1,51 @@
double log1p(double x) {
if (x == 0.0)
return 0.0;
if (x == -1.0)
return double.negativeInfinity;
if (x < -1.0)
return double.nan;
double xAbs = x.abs();
if (xAbs < 0.5 * 4.94065645841247e-324)
return x;
if ((x > 0.0 && x < 1e-8) || (x > -1e-9 && x < 0.0))
return x * (1.0 - x * 0.5);
if (xAbs < 0.375) {
List<double> coeffs = [
0.10378693562743769800686267719098e+1,
-0.13364301504908918098766041553133e+0,
0.19408249135520563357926199374750e-1,
-0.30107551127535777690376537776592e-2,
0.48694614797154850090456366509137e-3,
-0.81054881893175356066809943008622e-4,
0.13778847799559524782938251496059e-4,
-0.23802210894358970251369992914935e-5,
0.41640416213865183476391859901989e-6,
-0.73595828378075994984266837031998e-7,
0.13117611876241674949152294345011e-7,
-0.23546709317742425136696092330175e-8,
0.42522773276034997775638052962567e-9,
-0.77190894134840796826108107493300e-10,
0.14075746481359069909215356472191e-10,
-0.25769072058024680627537078627584e-11,
0.47342406666294421849154395005938e-12,
-0.87249012674742641745301263292675e-13,
0.16124614902740551465739833119115e-13,
-0.29875652015665773006710792416815e-14,
0.55480701209082887983041321697279e-15,
-0.10324619158271569595141333961932e-15];
return x * (1.0 - x * chebyshevBroucke(x / 0.375, coeffs));
}
return log(1.0 + x);
}
double chebyshevBroucke(double x, List<double> coeffs) {
double b0, b1, b2, x2;
b2 = b1 = b0 = 0.0;
x2 = x * 2;
for (int i = coeffs.length - 1; i >= 0; --i) {
b2 = b1;
b1 = b0;
b0 = x2 * b1 - b2 + coeffs[i];
}
return (b0 - b2) * 0.5;
}
2 changes: 2 additions & 0 deletions m2cgen/interpreters/go/interpreter.py
Expand Up @@ -18,6 +18,8 @@ class GoInterpreter(ImperativeToCodeInterpreter,

abs_function_name = "math.Abs"
exponent_function_name = "math.Exp"
logarithm_function_name = "math.Log"
log1p_function_name = "math.Log1p"
power_function_name = "math.Pow"
sqrt_function_name = "math.Sqrt"
tanh_function_name = "math.Tanh"
Expand Down
13 changes: 13 additions & 0 deletions m2cgen/interpreters/haskell/interpreter.py
Expand Up @@ -18,9 +18,13 @@ class HaskellInterpreter(ToCodeInterpreter,

abs_function_name = "abs"
exponent_function_name = "exp"
logarithm_function_name = "log"
log1p_function_name = "log1p"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"

with_log1p_expr = False

def __init__(self, module_name="Model", indent=4, function_name="score",
*args, **kwargs):
self.module_name = module_name
Expand Down Expand Up @@ -50,6 +54,11 @@ def interpret(self, expr):
os.path.dirname(__file__), "linear_algebra.hs")
self._cg.prepend_code_lines(utils.get_file_content(filename))

if self.with_log1p_expr:
filename = os.path.join(
os.path.dirname(__file__), "log1p.hs")
self._cg.prepend_code_lines(utils.get_file_content(filename))

self._cg.prepend_code_line(self._cg.tpl_module_definition(
module_name=self.module_name))

Expand Down Expand Up @@ -82,6 +91,10 @@ def interpret_pow_expr(self, expr, **kwargs):
return self._cg.infix_expression(
left=base_result, right=exp_result, op="**")

def interpret_log1p_expr(self, expr, **kwargs):
self.with_log1p_expr = True
return super().interpret_log1p_expr(expr, **kwargs)

# Cached expressions become functions with no arguments, i.e. values
# which are CAFs. Therefore, they are computed only once.
def _cache_reused_expr(self, expr, expr_result):
Expand Down
40 changes: 40 additions & 0 deletions m2cgen/interpreters/haskell/log1p.hs
@@ -0,0 +1,40 @@
log1p :: Double -> Double
log1p x
| x == 0 = 0
| x == -1 = -1 / 0
| x < -1 = 0 / 0
| x' < m_epsilon * 0.5 = x
| (x > 0 && x < 1e-8) || (x > -1e-9 && x < 0)
= x * (1 - x * 0.5)
| x' < 0.375 = x * (1 - x * chebyshevBroucke (x / 0.375) coeffs)
| otherwise = log (1 + x)
where
m_epsilon = encodeFloat (signif + 1) expo - 1.0
where (signif, expo) = decodeFloat (1.0::Double)
x' = abs x
coeffs = [0.10378693562743769800686267719098e+1,
-0.13364301504908918098766041553133e+0,
0.19408249135520563357926199374750e-1,
-0.30107551127535777690376537776592e-2,
0.48694614797154850090456366509137e-3,
-0.81054881893175356066809943008622e-4,
0.13778847799559524782938251496059e-4,
-0.23802210894358970251369992914935e-5,
0.41640416213865183476391859901989e-6,
-0.73595828378075994984266837031998e-7,
0.13117611876241674949152294345011e-7,
-0.23546709317742425136696092330175e-8,
0.42522773276034997775638052962567e-9,
-0.77190894134840796826108107493300e-10,
0.14075746481359069909215356472191e-10,
-0.25769072058024680627537078627584e-11,
0.47342406666294421849154395005938e-12,
-0.87249012674742641745301263292675e-13,
0.16124614902740551465739833119115e-13,
-0.29875652015665773006710792416815e-14,
0.55480701209082887983041321697279e-15,
-0.10324619158271569595141333961932e-15]
chebyshevBroucke i = fini . foldr step [0, 0, 0]
where
step k [b0, b1, _] = [(k + i * 2 * b0 - b1), b0, b1]
fini [b0, _, b2] = (b0 - b2) * 0.5
19 changes: 19 additions & 0 deletions m2cgen/interpreters/interpreter.py
Expand Up @@ -83,6 +83,8 @@ class ToCodeInterpreter(BaseToCodeInterpreter):

abs_function_name = NotImplemented
exponent_function_name = NotImplemented
logarithm_function_name = NotImplemented
log1p_function_name = NotImplemented
power_function_name = NotImplemented
sqrt_function_name = NotImplemented
tanh_function_name = NotImplemented
Expand Down Expand Up @@ -140,6 +142,23 @@ def interpret_exp_expr(self, expr, **kwargs):
return self._cg.function_invocation(
self.exponent_function_name, nested_result)

def interpret_log_expr(self, expr, **kwargs):
if self.logarithm_function_name is NotImplemented:
raise NotImplementedError("Logarithm function is not provided")
Comment on lines +146 to +147
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, I didn't find anything that we could use as a fallback expression for LogExpr. All methods are based on reduction technique which requires frexp implementation.
Refer for example to https://golang.org/src/math/log.go

self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.logarithm_function_name, nested_result)

def interpret_log1p_expr(self, expr, **kwargs):
self.with_math_module = True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we should set this flag only in case when we don't follow the fallback path. In case of the fallback we technically don't require the math module, although subsequent expressions that constitute the fallback expression may choose to enable it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't believe you ported the log1p function to all those languages

Thanks! Actually it was not so hard. All imperative languages are so similar! 😄

I believe we should set this flag only in case when we don't follow the fallback path.

Sorry, I'm not sure I fully understood you. Do you mean the following?

if self.log1p_function_name is NotImplemented:
            return self._do_interpret(
                fallback_expressions.log1p(expr.expr), **kwargs)
self.with_math_module = True

If so, it seems reasonable for me. But we should double check cases when fallbacks use functions from math module.

Anyway, I believe that this should be addressed in a separate PR because all other expressions follow the same pattern for now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is exactly what I mean!

But we should double check cases when fallbacks use functions from math module.

Yep, but those functions will also be interpreted and the flag will be set accordingly if needed.

Anyway, I believe that this should be addressed in a separate PR because all other expressions follow the same pattern for now.

I'm fine with that. We may want to address this for other functions as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I created new issue for that to not forget: #246.

if self.log1p_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.log1p(expr.expr), **kwargs)
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.log1p_function_name, nested_result)

def interpret_sqrt_expr(self, expr, **kwargs):
self.with_math_module = True
if self.sqrt_function_name is NotImplemented:
Expand Down
2 changes: 2 additions & 0 deletions m2cgen/interpreters/java/interpreter.py
Expand Up @@ -26,6 +26,8 @@ class JavaInterpreter(ImperativeToCodeInterpreter,

abs_function_name = "Math.abs"
exponent_function_name = "Math.exp"
logarithm_function_name = "Math.log"
log1p_function_name = "Math.log1p"
power_function_name = "Math.pow"
sqrt_function_name = "Math.sqrt"
tanh_function_name = "Math.tanh"
Expand Down