diff --git a/MANIFEST.in b/MANIFEST.in index 710be031..90f00e63 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,6 @@ include LICENSE recursive-include m2cgen VERSION.txt recursive-include m2cgen linear_algebra.* +recursive-include m2cgen log1p.* recursive-include m2cgen tanh.* global-exclude *.py[cod] diff --git a/README.md b/README.md index eaede06e..bd7292ec 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ pip install m2cgen - Python - R - Ruby -- Visual Basic +- Visual Basic (VBA-compatible) ## Supported Models diff --git a/m2cgen/assemblers/fallback_expressions.py b/m2cgen/assemblers/fallback_expressions.py index e7ed1b72..cd049c47 100644 --- a/m2cgen/assemblers/fallback_expressions.py +++ b/m2cgen/assemblers/fallback_expressions.py @@ -54,6 +54,18 @@ def exp(expr, to_reuse=False): to_reuse=to_reuse) +def log1p(expr): + # Use trick to compute log1p for small values more accurate + # https://www.johndcook.com/blog/2012/07/25/trick-for-computing-log1x/ + expr = ast.IdExpr(expr, to_reuse=True) + expr1p = utils.add(ast.NumVal(1.0), expr, to_reuse=True) + expr1pm1 = utils.sub(expr1p, ast.NumVal(1.0), to_reuse=True) + return ast.IfExpr( + utils.eq(expr1pm1, ast.NumVal(0.0)), + expr, + utils.div(utils.mul(expr, ast.LogExpr(expr1p)), expr1pm1)) + + def sigmoid(expr, to_reuse=False): neg_expr = ast.BinNumExpr(ast.NumVal(0), expr, ast.BinNumOpType.SUB) exp_expr = ast.ExpExpr(neg_expr) diff --git a/m2cgen/ast.py b/m2cgen/ast.py index f9e4f004..757885d5 100644 --- a/m2cgen/ast.py +++ b/m2cgen/ast.py @@ -106,6 +106,42 @@ def __hash__(self): return hash(self.expr) +class LogExpr(NumExpr): + def __init__(self, expr, to_reuse=False): + assert expr.output_size == 1, "Only scalars are supported" + + self.expr = expr + self.to_reuse = to_reuse + + def __str__(self): + args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)]) + return "LogExpr(" + args + ")" + + def __eq__(self, other): + return type(other) is LogExpr and self.expr == other.expr + + def __hash__(self): + return hash(self.expr) + + +class Log1pExpr(NumExpr): + def __init__(self, expr, to_reuse=False): + assert expr.output_size == 1, "Only scalars are supported" + + self.expr = expr + self.to_reuse = to_reuse + + def __str__(self): + args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)]) + return "Log1pExpr(" + args + ")" + + def __eq__(self, other): + return type(other) is Log1pExpr and self.expr == other.expr + + def __hash__(self): + return hash(self.expr) + + class SqrtExpr(NumExpr): def __init__(self, expr, to_reuse=False): assert expr.output_size == 1, "Only scalars are supported" @@ -354,7 +390,8 @@ def __hash__(self): (PowExpr, lambda e: [e.base_expr, e.exp_expr]), (VectorVal, lambda e: e.exprs), (IfExpr, lambda e: [e.test, e.body, e.orelse]), - ((AbsExpr, IdExpr, ExpExpr, SqrtExpr, TanhExpr), lambda e: [e.expr]), + ((AbsExpr, ExpExpr, IdExpr, LogExpr, Log1pExpr, SqrtExpr, TanhExpr), + lambda e: [e.expr]), ] diff --git a/m2cgen/interpreters/c/interpreter.py b/m2cgen/interpreters/c/interpreter.py index 4ff3fe05..8ff4a6c7 100644 --- a/m2cgen/interpreters/c/interpreter.py +++ b/m2cgen/interpreters/c/interpreter.py @@ -19,6 +19,8 @@ class CInterpreter(ImperativeToCodeInterpreter, abs_function_name = "fabs" exponent_function_name = "exp" + logarithm_function_name = "log" + log1p_function_name = "log1p" power_function_name = "pow" sqrt_function_name = "sqrt" tanh_function_name = "tanh" diff --git a/m2cgen/interpreters/c_sharp/interpreter.py b/m2cgen/interpreters/c_sharp/interpreter.py index d2227710..c9add466 100644 --- a/m2cgen/interpreters/c_sharp/interpreter.py +++ b/m2cgen/interpreters/c_sharp/interpreter.py @@ -20,10 +20,14 @@ class CSharpInterpreter(ImperativeToCodeInterpreter, abs_function_name = "Abs" exponent_function_name = "Exp" + logarithm_function_name = "Log" + log1p_function_name = "Log1p" power_function_name = "Pow" sqrt_function_name = "Sqrt" tanh_function_name = "Tanh" + with_log1p_expr = False + def __init__(self, namespace="ML", class_name="Model", indent=4, function_name="Score", *args, **kwargs): self.namespace = namespace @@ -56,7 +60,16 @@ def interpret(self, expr): os.path.dirname(__file__), "linear_algebra.cs") self._cg.add_code_lines(utils.get_file_content(filename)) + if self.with_log1p_expr: + filename = os.path.join( + os.path.dirname(__file__), "log1p.cs") + self._cg.add_code_lines(utils.get_file_content(filename)) + if self.with_math_module: self._cg.add_dependency("System.Math") return self._cg.finalize_and_get_generated_code() + + def interpret_log1p_expr(self, expr, **kwargs): + self.with_log1p_expr = True + return super().interpret_log1p_expr(expr, **kwargs) diff --git a/m2cgen/interpreters/c_sharp/log1p.cs b/m2cgen/interpreters/c_sharp/log1p.cs new file mode 100644 index 00000000..ac1d7744 --- /dev/null +++ b/m2cgen/interpreters/c_sharp/log1p.cs @@ -0,0 +1,51 @@ +private static double Log1p(double x) { + if (x == 0.0) + return 0.0; + if (x == -1.0) + return double.NegativeInfinity; + if (x < -1.0) + return double.NaN; + double xAbs = Abs(x); + if (xAbs < 0.5 * double.Epsilon) + return x; + if ((x > 0.0 && x < 1e-8) || (x > -1e-9 && x < 0.0)) + return x * (1.0 - x * 0.5); + if (xAbs < 0.375) { + double[] coeffs = { + 0.10378693562743769800686267719098e+1, + -0.13364301504908918098766041553133e+0, + 0.19408249135520563357926199374750e-1, + -0.30107551127535777690376537776592e-2, + 0.48694614797154850090456366509137e-3, + -0.81054881893175356066809943008622e-4, + 0.13778847799559524782938251496059e-4, + -0.23802210894358970251369992914935e-5, + 0.41640416213865183476391859901989e-6, + -0.73595828378075994984266837031998e-7, + 0.13117611876241674949152294345011e-7, + -0.23546709317742425136696092330175e-8, + 0.42522773276034997775638052962567e-9, + -0.77190894134840796826108107493300e-10, + 0.14075746481359069909215356472191e-10, + -0.25769072058024680627537078627584e-11, + 0.47342406666294421849154395005938e-12, + -0.87249012674742641745301263292675e-13, + 0.16124614902740551465739833119115e-13, + -0.29875652015665773006710792416815e-14, + 0.55480701209082887983041321697279e-15, + -0.10324619158271569595141333961932e-15}; + return x * (1.0 - x * ChebyshevBroucke(x / 0.375, coeffs)); + } + return Log(1.0 + x); +} +private static double ChebyshevBroucke(double x, double[] coeffs) { + double b0, b1, b2, x2; + b2 = b1 = b0 = 0.0; + x2 = x * 2; + for (int i = coeffs.Length - 1; i >= 0; --i) { + b2 = b1; + b1 = b0; + b0 = x2 * b1 - b2 + coeffs[i]; + } + return (b0 - b2) * 0.5; +} diff --git a/m2cgen/interpreters/dart/interpreter.py b/m2cgen/interpreters/dart/interpreter.py index fd24b926..d053d9bd 100644 --- a/m2cgen/interpreters/dart/interpreter.py +++ b/m2cgen/interpreters/dart/interpreter.py @@ -23,10 +23,13 @@ class DartInterpreter(ImperativeToCodeInterpreter, abs_function_name = "abs" exponent_function_name = "exp" + logarithm_function_name = "log" + log1p_function_name = "log1p" power_function_name = "pow" sqrt_function_name = "sqrt" tanh_function_name = "tanh" + with_log1p_expr = False with_tanh_expr = False def __init__(self, indent=4, function_name="score", *args, **kwargs): @@ -54,7 +57,11 @@ def interpret(self, expr): os.path.dirname(__file__), "linear_algebra.dart") self._cg.add_code_lines(utils.get_file_content(filename)) - # Use own tanh function in order to be compatible with Dart + if self.with_log1p_expr: + filename = os.path.join( + os.path.dirname(__file__), "log1p.dart") + self._cg.add_code_lines(utils.get_file_content(filename)) + if self.with_tanh_expr: filename = os.path.join( os.path.dirname(__file__), "tanh.dart") @@ -71,7 +78,10 @@ def interpret_abs_expr(self, expr, **kwargs): obj=self._do_interpret(expr.expr, **kwargs), args=[]) + def interpret_log1p_expr(self, expr, **kwargs): + self.with_log1p_expr = True + return super().interpret_log1p_expr(expr, **kwargs) + def interpret_tanh_expr(self, expr, **kwargs): self.with_tanh_expr = True - return super( - DartInterpreter, self).interpret_tanh_expr(expr, **kwargs) + return super().interpret_tanh_expr(expr, **kwargs) diff --git a/m2cgen/interpreters/dart/log1p.dart b/m2cgen/interpreters/dart/log1p.dart new file mode 100644 index 00000000..2aa42894 --- /dev/null +++ b/m2cgen/interpreters/dart/log1p.dart @@ -0,0 +1,51 @@ +double log1p(double x) { + if (x == 0.0) + return 0.0; + if (x == -1.0) + return double.negativeInfinity; + if (x < -1.0) + return double.nan; + double xAbs = x.abs(); + if (xAbs < 0.5 * 4.94065645841247e-324) + return x; + if ((x > 0.0 && x < 1e-8) || (x > -1e-9 && x < 0.0)) + return x * (1.0 - x * 0.5); + if (xAbs < 0.375) { + List coeffs = [ + 0.10378693562743769800686267719098e+1, + -0.13364301504908918098766041553133e+0, + 0.19408249135520563357926199374750e-1, + -0.30107551127535777690376537776592e-2, + 0.48694614797154850090456366509137e-3, + -0.81054881893175356066809943008622e-4, + 0.13778847799559524782938251496059e-4, + -0.23802210894358970251369992914935e-5, + 0.41640416213865183476391859901989e-6, + -0.73595828378075994984266837031998e-7, + 0.13117611876241674949152294345011e-7, + -0.23546709317742425136696092330175e-8, + 0.42522773276034997775638052962567e-9, + -0.77190894134840796826108107493300e-10, + 0.14075746481359069909215356472191e-10, + -0.25769072058024680627537078627584e-11, + 0.47342406666294421849154395005938e-12, + -0.87249012674742641745301263292675e-13, + 0.16124614902740551465739833119115e-13, + -0.29875652015665773006710792416815e-14, + 0.55480701209082887983041321697279e-15, + -0.10324619158271569595141333961932e-15]; + return x * (1.0 - x * chebyshevBroucke(x / 0.375, coeffs)); + } + return log(1.0 + x); +} +double chebyshevBroucke(double x, List coeffs) { + double b0, b1, b2, x2; + b2 = b1 = b0 = 0.0; + x2 = x * 2; + for (int i = coeffs.length - 1; i >= 0; --i) { + b2 = b1; + b1 = b0; + b0 = x2 * b1 - b2 + coeffs[i]; + } + return (b0 - b2) * 0.5; +} diff --git a/m2cgen/interpreters/go/interpreter.py b/m2cgen/interpreters/go/interpreter.py index 41375341..71a24b1e 100644 --- a/m2cgen/interpreters/go/interpreter.py +++ b/m2cgen/interpreters/go/interpreter.py @@ -18,6 +18,8 @@ class GoInterpreter(ImperativeToCodeInterpreter, abs_function_name = "math.Abs" exponent_function_name = "math.Exp" + logarithm_function_name = "math.Log" + log1p_function_name = "math.Log1p" power_function_name = "math.Pow" sqrt_function_name = "math.Sqrt" tanh_function_name = "math.Tanh" diff --git a/m2cgen/interpreters/haskell/interpreter.py b/m2cgen/interpreters/haskell/interpreter.py index eb340733..a1e9f593 100644 --- a/m2cgen/interpreters/haskell/interpreter.py +++ b/m2cgen/interpreters/haskell/interpreter.py @@ -18,9 +18,13 @@ class HaskellInterpreter(ToCodeInterpreter, abs_function_name = "abs" exponent_function_name = "exp" + logarithm_function_name = "log" + log1p_function_name = "log1p" sqrt_function_name = "sqrt" tanh_function_name = "tanh" + with_log1p_expr = False + def __init__(self, module_name="Model", indent=4, function_name="score", *args, **kwargs): self.module_name = module_name @@ -50,6 +54,11 @@ def interpret(self, expr): os.path.dirname(__file__), "linear_algebra.hs") self._cg.prepend_code_lines(utils.get_file_content(filename)) + if self.with_log1p_expr: + filename = os.path.join( + os.path.dirname(__file__), "log1p.hs") + self._cg.prepend_code_lines(utils.get_file_content(filename)) + self._cg.prepend_code_line(self._cg.tpl_module_definition( module_name=self.module_name)) @@ -82,6 +91,10 @@ def interpret_pow_expr(self, expr, **kwargs): return self._cg.infix_expression( left=base_result, right=exp_result, op="**") + def interpret_log1p_expr(self, expr, **kwargs): + self.with_log1p_expr = True + return super().interpret_log1p_expr(expr, **kwargs) + # Cached expressions become functions with no arguments, i.e. values # which are CAFs. Therefore, they are computed only once. def _cache_reused_expr(self, expr, expr_result): diff --git a/m2cgen/interpreters/haskell/log1p.hs b/m2cgen/interpreters/haskell/log1p.hs new file mode 100644 index 00000000..511fd03a --- /dev/null +++ b/m2cgen/interpreters/haskell/log1p.hs @@ -0,0 +1,40 @@ +log1p :: Double -> Double +log1p x + | x == 0 = 0 + | x == -1 = -1 / 0 + | x < -1 = 0 / 0 + | x' < m_epsilon * 0.5 = x + | (x > 0 && x < 1e-8) || (x > -1e-9 && x < 0) + = x * (1 - x * 0.5) + | x' < 0.375 = x * (1 - x * chebyshevBroucke (x / 0.375) coeffs) + | otherwise = log (1 + x) + where + m_epsilon = encodeFloat (signif + 1) expo - 1.0 + where (signif, expo) = decodeFloat (1.0::Double) + x' = abs x + coeffs = [0.10378693562743769800686267719098e+1, + -0.13364301504908918098766041553133e+0, + 0.19408249135520563357926199374750e-1, + -0.30107551127535777690376537776592e-2, + 0.48694614797154850090456366509137e-3, + -0.81054881893175356066809943008622e-4, + 0.13778847799559524782938251496059e-4, + -0.23802210894358970251369992914935e-5, + 0.41640416213865183476391859901989e-6, + -0.73595828378075994984266837031998e-7, + 0.13117611876241674949152294345011e-7, + -0.23546709317742425136696092330175e-8, + 0.42522773276034997775638052962567e-9, + -0.77190894134840796826108107493300e-10, + 0.14075746481359069909215356472191e-10, + -0.25769072058024680627537078627584e-11, + 0.47342406666294421849154395005938e-12, + -0.87249012674742641745301263292675e-13, + 0.16124614902740551465739833119115e-13, + -0.29875652015665773006710792416815e-14, + 0.55480701209082887983041321697279e-15, + -0.10324619158271569595141333961932e-15] + chebyshevBroucke i = fini . foldr step [0, 0, 0] + where + step k [b0, b1, _] = [(k + i * 2 * b0 - b1), b0, b1] + fini [b0, _, b2] = (b0 - b2) * 0.5 diff --git a/m2cgen/interpreters/interpreter.py b/m2cgen/interpreters/interpreter.py index 536e7778..1e8fdaea 100644 --- a/m2cgen/interpreters/interpreter.py +++ b/m2cgen/interpreters/interpreter.py @@ -83,6 +83,8 @@ class ToCodeInterpreter(BaseToCodeInterpreter): abs_function_name = NotImplemented exponent_function_name = NotImplemented + logarithm_function_name = NotImplemented + log1p_function_name = NotImplemented power_function_name = NotImplemented sqrt_function_name = NotImplemented tanh_function_name = NotImplemented @@ -140,6 +142,23 @@ def interpret_exp_expr(self, expr, **kwargs): return self._cg.function_invocation( self.exponent_function_name, nested_result) + def interpret_log_expr(self, expr, **kwargs): + if self.logarithm_function_name is NotImplemented: + raise NotImplementedError("Logarithm function is not provided") + self.with_math_module = True + nested_result = self._do_interpret(expr.expr, **kwargs) + return self._cg.function_invocation( + self.logarithm_function_name, nested_result) + + def interpret_log1p_expr(self, expr, **kwargs): + self.with_math_module = True + if self.log1p_function_name is NotImplemented: + return self._do_interpret( + fallback_expressions.log1p(expr.expr), **kwargs) + nested_result = self._do_interpret(expr.expr, **kwargs) + return self._cg.function_invocation( + self.log1p_function_name, nested_result) + def interpret_sqrt_expr(self, expr, **kwargs): self.with_math_module = True if self.sqrt_function_name is NotImplemented: diff --git a/m2cgen/interpreters/java/interpreter.py b/m2cgen/interpreters/java/interpreter.py index 367bbb68..1951a42d 100644 --- a/m2cgen/interpreters/java/interpreter.py +++ b/m2cgen/interpreters/java/interpreter.py @@ -26,6 +26,8 @@ class JavaInterpreter(ImperativeToCodeInterpreter, abs_function_name = "Math.abs" exponent_function_name = "Math.exp" + logarithm_function_name = "Math.log" + log1p_function_name = "Math.log1p" power_function_name = "Math.pow" sqrt_function_name = "Math.sqrt" tanh_function_name = "Math.tanh" diff --git a/m2cgen/interpreters/javascript/interpreter.py b/m2cgen/interpreters/javascript/interpreter.py index fee3d774..383c41c4 100644 --- a/m2cgen/interpreters/javascript/interpreter.py +++ b/m2cgen/interpreters/javascript/interpreter.py @@ -21,6 +21,8 @@ class JavascriptInterpreter(ImperativeToCodeInterpreter, abs_function_name = "Math.abs" exponent_function_name = "Math.exp" + logarithm_function_name = "Math.log" + log1p_function_name = "Math.log1p" power_function_name = "Math.pow" sqrt_function_name = "Math.sqrt" tanh_function_name = "Math.tanh" diff --git a/m2cgen/interpreters/php/interpreter.py b/m2cgen/interpreters/php/interpreter.py index de5f496b..814b2b59 100644 --- a/m2cgen/interpreters/php/interpreter.py +++ b/m2cgen/interpreters/php/interpreter.py @@ -19,6 +19,8 @@ class PhpInterpreter(ImperativeToCodeInterpreter, abs_function_name = "abs" exponent_function_name = "exp" + logarithm_function_name = "log" + log1p_function_name = "log1p" power_function_name = "pow" sqrt_function_name = "sqrt" tanh_function_name = "tanh" diff --git a/m2cgen/interpreters/powershell/interpreter.py b/m2cgen/interpreters/powershell/interpreter.py index 7715cd20..6c220c50 100644 --- a/m2cgen/interpreters/powershell/interpreter.py +++ b/m2cgen/interpreters/powershell/interpreter.py @@ -20,10 +20,14 @@ class PowershellInterpreter(ImperativeToCodeInterpreter, abs_function_name = "[math]::Abs" exponent_function_name = "[math]::Exp" + logarithm_function_name = "[math]::Log" + log1p_function_name = "Log1p" power_function_name = "[math]::Pow" sqrt_function_name = "[math]::Sqrt" tanh_function_name = "[math]::Tanh" + with_log1p_expr = False + def __init__(self, indent=4, function_name="Score", *args, **kwargs): self.function_name = function_name @@ -47,6 +51,11 @@ def interpret(self, expr): os.path.dirname(__file__), "linear_algebra.ps1") self._cg.prepend_code_lines(utils.get_file_content(filename)) + if self.with_log1p_expr: + filename = os.path.join( + os.path.dirname(__file__), "log1p.ps1") + self._cg.prepend_code_lines(utils.get_file_content(filename)) + return self._cg.finalize_and_get_generated_code() def interpret_abs_expr(self, expr, **kwargs): @@ -59,6 +68,15 @@ def interpret_exp_expr(self, expr, **kwargs): return self._cg.math_function_invocation( self.exponent_function_name, nested_result) + def interpret_log_expr(self, expr, **kwargs): + nested_result = self._do_interpret(expr.expr, **kwargs) + return self._cg.math_function_invocation( + self.logarithm_function_name, nested_result) + + def interpret_log1p_expr(self, expr, **kwargs): + self.with_log1p_expr = True + return super().interpret_log1p_expr(expr, **kwargs) + def interpret_sqrt_expr(self, expr, **kwargs): nested_result = self._do_interpret(expr.expr, **kwargs) return self._cg.math_function_invocation( diff --git a/m2cgen/interpreters/powershell/log1p.ps1 b/m2cgen/interpreters/powershell/log1p.ps1 new file mode 100644 index 00000000..d4e61e36 --- /dev/null +++ b/m2cgen/interpreters/powershell/log1p.ps1 @@ -0,0 +1,48 @@ +function Log1p([double] $x) { + if ($x -eq 0.0) { return 0.0 } + if ($x -eq -1.0) { return [double]::NegativeInfinity } + if ($x -lt -1.0) { return [double]::NaN } + [double] $xAbs = [math]::Abs($x) + if ($xAbs -lt 0.5 * [double]::Epsilon) { return $x } + if ((($x -gt 0.0) -and ($x -lt 1e-8)) + -or (($x -gt -1e-9) -and ($x -lt 0.0))) { + return $x * (1.0 - $x * 0.5) + } + if ($xAbs -lt 0.375) { + [double[]] $coeffs = @( + 0.10378693562743769800686267719098e+1, + -0.13364301504908918098766041553133e+0, + 0.19408249135520563357926199374750e-1, + -0.30107551127535777690376537776592e-2, + 0.48694614797154850090456366509137e-3, + -0.81054881893175356066809943008622e-4, + 0.13778847799559524782938251496059e-4, + -0.23802210894358970251369992914935e-5, + 0.41640416213865183476391859901989e-6, + -0.73595828378075994984266837031998e-7, + 0.13117611876241674949152294345011e-7, + -0.23546709317742425136696092330175e-8, + 0.42522773276034997775638052962567e-9, + -0.77190894134840796826108107493300e-10, + 0.14075746481359069909215356472191e-10, + -0.25769072058024680627537078627584e-11, + 0.47342406666294421849154395005938e-12, + -0.87249012674742641745301263292675e-13, + 0.16124614902740551465739833119115e-13, + -0.29875652015665773006710792416815e-14, + 0.55480701209082887983041321697279e-15, + -0.10324619158271569595141333961932e-15) + return $x * (1.0 - $x * (Chebyshev-Broucke ($x / 0.375) $coeffs)) + } + return [math]::Log(1.0 + $x) +} +function Chebyshev-Broucke([double] $x, [double[]] $coeffs) { + [double] $b2 = [double] $b1 = [double] $b0 = 0.0 + [double] $x2 = $x * 2 + for ([int] $i = $coeffs.Length - 1; $i -ge 0; --$i) { + $b2 = $b1 + $b1 = $b0 + $b0 = $x2 * $b1 - $b2 + $coeffs[$i] + } + return ($b0 - $b2) * 0.5 +} diff --git a/m2cgen/interpreters/python/interpreter.py b/m2cgen/interpreters/python/interpreter.py index 1e3b5d4f..a32d6dea 100644 --- a/m2cgen/interpreters/python/interpreter.py +++ b/m2cgen/interpreters/python/interpreter.py @@ -23,6 +23,8 @@ class PythonInterpreter(ImperativeToCodeInterpreter, abs_function_name = "abs" exponent_function_name = "math.exp" + logarithm_function_name = "math.log" + log1p_function_name = "math.log1p" power_function_name = "math.pow" sqrt_function_name = "math.sqrt" tanh_function_name = "math.tanh" diff --git a/m2cgen/interpreters/r/interpreter.py b/m2cgen/interpreters/r/interpreter.py index d8ba6245..c82fdf2b 100644 --- a/m2cgen/interpreters/r/interpreter.py +++ b/m2cgen/interpreters/r/interpreter.py @@ -24,6 +24,8 @@ class RInterpreter(ImperativeToCodeInterpreter, abs_function_name = "abs" exponent_function_name = "exp" + logarithm_function_name = "log" + log1p_function_name = "log1p" sqrt_function_name = "sqrt" tanh_function_name = "tanh" diff --git a/m2cgen/interpreters/ruby/interpreter.py b/m2cgen/interpreters/ruby/interpreter.py index 3004376c..909cd57d 100644 --- a/m2cgen/interpreters/ruby/interpreter.py +++ b/m2cgen/interpreters/ruby/interpreter.py @@ -19,9 +19,13 @@ class RubyInterpreter(ImperativeToCodeInterpreter, abs_function_name = "abs" exponent_function_name = "Math.exp" + logarithm_function_name = "Math.log" + log1p_function_name = "log1p" sqrt_function_name = "Math.sqrt" tanh_function_name = "Math.tanh" + with_log1p_expr = False + def __init__(self, indent=4, function_name="score", *args, **kwargs): self.function_name = function_name @@ -43,6 +47,11 @@ def interpret(self, expr): os.path.dirname(__file__), "linear_algebra.rb") self._cg.prepend_code_lines(utils.get_file_content(filename)) + if self.with_log1p_expr: + filename = os.path.join( + os.path.dirname(__file__), "log1p.rb") + self._cg.add_code_lines(utils.get_file_content(filename)) + return self._cg.finalize_and_get_generated_code() def interpret_bin_num_expr(self, expr, **kwargs): @@ -66,3 +75,7 @@ def interpret_pow_expr(self, expr, **kwargs): exp_result = self._do_interpret(expr.exp_expr, **kwargs) return self._cg.infix_expression( left=base_result, right=exp_result, op="**") + + def interpret_log1p_expr(self, expr, **kwargs): + self.with_log1p_expr = True + return super().interpret_log1p_expr(expr, **kwargs) diff --git a/m2cgen/interpreters/ruby/log1p.rb b/m2cgen/interpreters/ruby/log1p.rb new file mode 100644 index 00000000..89c1cbc2 --- /dev/null +++ b/m2cgen/interpreters/ruby/log1p.rb @@ -0,0 +1,55 @@ +def log1p(x) + if x == 0.0 + return 0.0 + end + if x == -1.0 + return -Float::INFINITY + end + if x < -1.0 + return Float::NAN + end + x_abs = x.abs + if x_abs < 0.5 * Float::EPSILON + return x + end + if (x > 0.0 && x < 1e-8) || (x > -1e-9 && x < 0.0) + return x * (1.0 - x * 0.5) + end + if x_abs < 0.375 + coeffs = [ + 0.10378693562743769800686267719098e+1, + -0.13364301504908918098766041553133e+0, + 0.19408249135520563357926199374750e-1, + -0.30107551127535777690376537776592e-2, + 0.48694614797154850090456366509137e-3, + -0.81054881893175356066809943008622e-4, + 0.13778847799559524782938251496059e-4, + -0.23802210894358970251369992914935e-5, + 0.41640416213865183476391859901989e-6, + -0.73595828378075994984266837031998e-7, + 0.13117611876241674949152294345011e-7, + -0.23546709317742425136696092330175e-8, + 0.42522773276034997775638052962567e-9, + -0.77190894134840796826108107493300e-10, + 0.14075746481359069909215356472191e-10, + -0.25769072058024680627537078627584e-11, + 0.47342406666294421849154395005938e-12, + -0.87249012674742641745301263292675e-13, + 0.16124614902740551465739833119115e-13, + -0.29875652015665773006710792416815e-14, + 0.55480701209082887983041321697279e-15, + -0.10324619158271569595141333961932e-15] + return x * (1.0 - x * chebyshev_broucke(x / 0.375, coeffs)) + end + return Math.log(1.0 + x) +end +def chebyshev_broucke(x, coeffs) + b2 = b1 = b0 = 0.0 + x2 = x * 2 + coeffs.reverse_each do |i| + b2 = b1 + b1 = b0 + b0 = x2 * b1 - b2 + i + end + (b0 - b2) * 0.5 +end diff --git a/m2cgen/interpreters/visual_basic/interpreter.py b/m2cgen/interpreters/visual_basic/interpreter.py index 7884fee2..1a845e89 100644 --- a/m2cgen/interpreters/visual_basic/interpreter.py +++ b/m2cgen/interpreters/visual_basic/interpreter.py @@ -19,8 +19,11 @@ class VisualBasicInterpreter(ImperativeToCodeInterpreter, abs_function_name = "Math.Abs" exponent_function_name = "Math.Exp" + logarithm_function_name = "Math.Log" + log1p_function_name = "Log1p" tanh_function_name = "Tanh" + with_log1p_expr = False with_tanh_expr = False def __init__(self, module_name="Model", indent=4, function_name="Score", @@ -56,6 +59,11 @@ def interpret(self, expr): os.path.dirname(__file__), "tanh.bas") self._cg.prepend_code_lines(utils.get_file_content(filename)) + if self.with_log1p_expr: + filename = os.path.join( + os.path.dirname(__file__), "log1p.bas") + self._cg.prepend_code_lines(utils.get_file_content(filename)) + self._cg.prepend_code_line(self._cg.tpl_module_definition( module_name=self.module_name)) self._cg.add_code_line(self._cg.tpl_block_termination( @@ -69,7 +77,10 @@ def interpret_pow_expr(self, expr, **kwargs): return self._cg.infix_expression( left=base_result, right=exp_result, op="^") + def interpret_log1p_expr(self, expr, **kwargs): + self.with_log1p_expr = True + return super().interpret_log1p_expr(expr, **kwargs) + def interpret_tanh_expr(self, expr, **kwargs): self.with_tanh_expr = True - return super( - VisualBasicInterpreter, self).interpret_tanh_expr(expr, **kwargs) + return super().interpret_tanh_expr(expr, **kwargs) diff --git a/m2cgen/interpreters/visual_basic/log1p.bas b/m2cgen/interpreters/visual_basic/log1p.bas new file mode 100644 index 00000000..a079c148 --- /dev/null +++ b/m2cgen/interpreters/visual_basic/log1p.bas @@ -0,0 +1,72 @@ +Function ChebyshevBroucke(ByVal x As Double, _ + ByRef coeffs() As Double) As Double + Dim b2 as Double + Dim b1 as Double + Dim b0 as Double + Dim x2 as Double + b2 = 0.0 + b1 = 0.0 + b0 = 0.0 + x2 = x * 2 + Dim i as Integer + For i = UBound(coeffs) - 1 To 0 Step -1 + b2 = b1 + b1 = b0 + b0 = x2 * b1 - b2 + coeffs(i) + Next i + ChebyshevBroucke = (b0 - b2) * 0.5 +End Function +Function Log1p(ByVal x As Double) As Double + If x = 0.0 Then + Log1p = 0.0 + Exit Function + End If + If x = -1.0 Then + On Error Resume Next + Log1p = -1.0 / 0.0 + Exit Function + End If + If x < -1.0 Then + On Error Resume Next + Log1p = 0.0 / 0.0 + Exit Function + End If + Dim xAbs As Double + xAbs = Math.Abs(x) + If xAbs < 0.5 * 4.94065645841247e-324 Then + Log1p = x + Exit Function + End If + If (x > 0.0 AND x < 1e-8) OR (x > -1e-9 AND x < 0.0) Then + Log1p = x * (1.0 - x * 0.5) + Exit Function + End If + If xAbs < 0.375 Then + Dim coeffs(22) As Double + coeffs(0) = 0.10378693562743769800686267719098e+1 + coeffs(1) = -0.13364301504908918098766041553133e+0 + coeffs(2) = 0.19408249135520563357926199374750e-1 + coeffs(3) = -0.30107551127535777690376537776592e-2 + coeffs(4) = 0.48694614797154850090456366509137e-3 + coeffs(5) = -0.81054881893175356066809943008622e-4 + coeffs(6) = 0.13778847799559524782938251496059e-4 + coeffs(7) = -0.23802210894358970251369992914935e-5 + coeffs(8) = 0.41640416213865183476391859901989e-6 + coeffs(9) = -0.73595828378075994984266837031998e-7 + coeffs(10) = 0.13117611876241674949152294345011e-7 + coeffs(11) = -0.23546709317742425136696092330175e-8 + coeffs(12) = 0.42522773276034997775638052962567e-9 + coeffs(13) = -0.77190894134840796826108107493300e-10 + coeffs(14) = 0.14075746481359069909215356472191e-10 + coeffs(15) = -0.25769072058024680627537078627584e-11 + coeffs(16) = 0.47342406666294421849154395005938e-12 + coeffs(17) = -0.87249012674742641745301263292675e-13 + coeffs(18) = 0.16124614902740551465739833119115e-13 + coeffs(19) = -0.29875652015665773006710792416815e-14 + coeffs(20) = 0.55480701209082887983041321697279e-15 + coeffs(21) = -0.10324619158271569595141333961932e-15 + Log1p = x * (1.0 - x * ChebyshevBroucke(x / 0.375, coeffs)) + Exit Function + End If + Log1p = Math.log(1.0 + x) +End Function diff --git a/tests/interpreters/test_c.py b/tests/interpreters/test_c.py index c6e0ae79..86db19d9 100644 --- a/tests/interpreters/test_c.py +++ b/tests/interpreters/test_c.py @@ -282,6 +282,34 @@ def test_tanh_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_log_expr(): + expr = ast.LogExpr(ast.NumVal(2.0)) + + interpreter = interpreters.CInterpreter() + + expected_code = """ +#include +double score(double * input) { + return log(2.0); +}""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_log1p_expr(): + expr = ast.Log1pExpr(ast.NumVal(2.0)) + + interpreter = interpreters.CInterpreter() + + expected_code = """ +#include +double score(double * input) { + return log1p(2.0); +}""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_c_sharp.py b/tests/interpreters/test_c_sharp.py index 6a6878ff..963c146b 100644 --- a/tests/interpreters/test_c_sharp.py +++ b/tests/interpreters/test_c_sharp.py @@ -383,6 +383,93 @@ def test_tanh_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_log_expr(): + expr = ast.LogExpr(ast.NumVal(2.0)) + + expected_code = """ +using static System.Math; +namespace ML { + public static class Model { + public static double Score(double[] input) { + return Log(2.0); + } + } +} +""" + + interpreter = CSharpInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_log1p_expr(): + expr = ast.Log1pExpr(ast.NumVal(2.0)) + + expected_code = """ +using static System.Math; +namespace ML { + public static class Model { + public static double Score(double[] input) { + return Log1p(2.0); + } + private static double Log1p(double x) { + if (x == 0.0) + return 0.0; + if (x == -1.0) + return double.NegativeInfinity; + if (x < -1.0) + return double.NaN; + double xAbs = Abs(x); + if (xAbs < 0.5 * double.Epsilon) + return x; + if ((x > 0.0 && x < 1e-8) || (x > -1e-9 && x < 0.0)) + return x * (1.0 - x * 0.5); + if (xAbs < 0.375) { + double[] coeffs = { + 0.10378693562743769800686267719098e+1, + -0.13364301504908918098766041553133e+0, + 0.19408249135520563357926199374750e-1, + -0.30107551127535777690376537776592e-2, + 0.48694614797154850090456366509137e-3, + -0.81054881893175356066809943008622e-4, + 0.13778847799559524782938251496059e-4, + -0.23802210894358970251369992914935e-5, + 0.41640416213865183476391859901989e-6, + -0.73595828378075994984266837031998e-7, + 0.13117611876241674949152294345011e-7, + -0.23546709317742425136696092330175e-8, + 0.42522773276034997775638052962567e-9, + -0.77190894134840796826108107493300e-10, + 0.14075746481359069909215356472191e-10, + -0.25769072058024680627537078627584e-11, + 0.47342406666294421849154395005938e-12, + -0.87249012674742641745301263292675e-13, + 0.16124614902740551465739833119115e-13, + -0.29875652015665773006710792416815e-14, + 0.55480701209082887983041321697279e-15, + -0.10324619158271569595141333961932e-15}; + return x * (1.0 - x * ChebyshevBroucke(x / 0.375, coeffs)); + } + return Log(1.0 + x); + } + private static double ChebyshevBroucke(double x, double[] coeffs) { + double b0, b1, b2, x2; + b2 = b1 = b0 = 0.0; + x2 = x * 2; + for (int i = coeffs.Length - 1; i >= 0; --i) { + b2 = b1; + b1 = b0; + b0 = x2 * b1 - b2 + coeffs[i]; + } + return (b0 - b2) * 0.5; + } + } +} +""" + + interpreter = CSharpInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_dart.py b/tests/interpreters/test_dart.py index 493c059a..09ac2a6a 100644 --- a/tests/interpreters/test_dart.py +++ b/tests/interpreters/test_dart.py @@ -476,6 +476,83 @@ def test_tanh_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_log_expr(): + expr = ast.LogExpr(ast.NumVal(2.0)) + + expected_code = """ +import 'dart:math'; +double score(List input) { + return log(2.0); +} +""" + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_log1p_expr(): + expr = ast.Log1pExpr(ast.NumVal(2.0)) + + expected_code = """ +import 'dart:math'; +double score(List input) { + return log1p(2.0); +} +double log1p(double x) { + if (x == 0.0) + return 0.0; + if (x == -1.0) + return double.negativeInfinity; + if (x < -1.0) + return double.nan; + double xAbs = x.abs(); + if (xAbs < 0.5 * 4.94065645841247e-324) + return x; + if ((x > 0.0 && x < 1e-8) || (x > -1e-9 && x < 0.0)) + return x * (1.0 - x * 0.5); + if (xAbs < 0.375) { + List coeffs = [ + 0.10378693562743769800686267719098e+1, + -0.13364301504908918098766041553133e+0, + 0.19408249135520563357926199374750e-1, + -0.30107551127535777690376537776592e-2, + 0.48694614797154850090456366509137e-3, + -0.81054881893175356066809943008622e-4, + 0.13778847799559524782938251496059e-4, + -0.23802210894358970251369992914935e-5, + 0.41640416213865183476391859901989e-6, + -0.73595828378075994984266837031998e-7, + 0.13117611876241674949152294345011e-7, + -0.23546709317742425136696092330175e-8, + 0.42522773276034997775638052962567e-9, + -0.77190894134840796826108107493300e-10, + 0.14075746481359069909215356472191e-10, + -0.25769072058024680627537078627584e-11, + 0.47342406666294421849154395005938e-12, + -0.87249012674742641745301263292675e-13, + 0.16124614902740551465739833119115e-13, + -0.29875652015665773006710792416815e-14, + 0.55480701209082887983041321697279e-15, + -0.10324619158271569595141333961932e-15]; + return x * (1.0 - x * chebyshevBroucke(x / 0.375, coeffs)); + } + return log(1.0 + x); +} +double chebyshevBroucke(double x, List coeffs) { + double b0, b1, b2, x2; + b2 = b1 = b0 = 0.0; + x2 = x * 2; + for (int i = coeffs.length - 1; i >= 0; --i) { + b2 = b1; + b1 = b0; + b0 = x2 * b1 - b2 + coeffs[i]; + } + return (b0 - b2) * 0.5; +} +""" + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_go.py b/tests/interpreters/test_go.py index c6f976c6..69b8cc49 100644 --- a/tests/interpreters/test_go.py +++ b/tests/interpreters/test_go.py @@ -285,6 +285,34 @@ def test_tanh_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_log_expr(): + expr = ast.LogExpr(ast.NumVal(2.0)) + + interpreter = interpreters.GoInterpreter() + + expected_code = """ +import "math" +func score(input []float64) float64 { + return math.Log(2.0) +}""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_log1p_expr(): + expr = ast.Log1pExpr(ast.NumVal(2.0)) + + interpreter = interpreters.GoInterpreter() + + expected_code = """ +import "math" +func score(input []float64) float64 { + return math.Log1p(2.0) +}""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_haskell.py b/tests/interpreters/test_haskell.py index eb336850..5b41e666 100644 --- a/tests/interpreters/test_haskell.py +++ b/tests/interpreters/test_haskell.py @@ -285,6 +285,74 @@ def test_tanh_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_log_expr(): + expr = ast.LogExpr(ast.NumVal(2.0)) + + expected_code = """ +module Model where +score :: [Double] -> Double +score input = + log (2.0) +""" + + interpreter = HaskellInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_log1p_expr(): + expr = ast.Log1pExpr(ast.NumVal(2.0)) + + expected_code = """ +module Model where +log1p :: Double -> Double +log1p x + | x == 0 = 0 + | x == -1 = -1 / 0 + | x < -1 = 0 / 0 + | x' < m_epsilon * 0.5 = x + | (x > 0 && x < 1e-8) || (x > -1e-9 && x < 0) + = x * (1 - x * 0.5) + | x' < 0.375 = x * (1 - x * chebyshevBroucke (x / 0.375) coeffs) + | otherwise = log (1 + x) + where + m_epsilon = encodeFloat (signif + 1) expo - 1.0 + where (signif, expo) = decodeFloat (1.0::Double) + x' = abs x + coeffs = [0.10378693562743769800686267719098e+1, + -0.13364301504908918098766041553133e+0, + 0.19408249135520563357926199374750e-1, + -0.30107551127535777690376537776592e-2, + 0.48694614797154850090456366509137e-3, + -0.81054881893175356066809943008622e-4, + 0.13778847799559524782938251496059e-4, + -0.23802210894358970251369992914935e-5, + 0.41640416213865183476391859901989e-6, + -0.73595828378075994984266837031998e-7, + 0.13117611876241674949152294345011e-7, + -0.23546709317742425136696092330175e-8, + 0.42522773276034997775638052962567e-9, + -0.77190894134840796826108107493300e-10, + 0.14075746481359069909215356472191e-10, + -0.25769072058024680627537078627584e-11, + 0.47342406666294421849154395005938e-12, + -0.87249012674742641745301263292675e-13, + 0.16124614902740551465739833119115e-13, + -0.29875652015665773006710792416815e-14, + 0.55480701209082887983041321697279e-15, + -0.10324619158271569595141333961932e-15] + chebyshevBroucke i = fini . foldr step [0, 0, 0] + where + step k [b0, b1, _] = [(k + i * 2 * b0 - b1), b0, b1] + fini [b0, _, b2] = (b0 - b2) * 0.5 +score :: [Double] -> Double +score input = + log1p (2.0) +""" + + interpreter = HaskellInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_java.py b/tests/interpreters/test_java.py index b7e5bdab..e0de4813 100644 --- a/tests/interpreters/test_java.py +++ b/tests/interpreters/test_java.py @@ -347,6 +347,36 @@ def test_tanh_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_log_expr(): + expr = ast.LogExpr(ast.NumVal(2.0)) + + interpreter = interpreters.JavaInterpreter() + + expected_code = """ +public class Model { + public static double score(double[] input) { + return Math.log(2.0); + } +}""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_log1p_expr(): + expr = ast.Log1pExpr(ast.NumVal(2.0)) + + interpreter = interpreters.JavaInterpreter() + + expected_code = """ +public class Model { + public static double score(double[] input) { + return Math.log1p(2.0); + } +}""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_javascript.py b/tests/interpreters/test_javascript.py index e3a8e24d..8311fda1 100644 --- a/tests/interpreters/test_javascript.py +++ b/tests/interpreters/test_javascript.py @@ -300,6 +300,34 @@ def test_tanh_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_log_expr(): + expr = ast.LogExpr(ast.NumVal(2.0)) + + interpreter = interpreters.JavascriptInterpreter() + + expected_code = """ +function score(input) { + return Math.log(2.0); +} +""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_log1p_expr(): + expr = ast.Log1pExpr(ast.NumVal(2.0)) + + interpreter = interpreters.JavascriptInterpreter() + + expected_code = """ +function score(input) { + return Math.log1p(2.0); +} +""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_php.py b/tests/interpreters/test_php.py index ddacccf0..dcf2edbf 100644 --- a/tests/interpreters/test_php.py +++ b/tests/interpreters/test_php.py @@ -305,6 +305,34 @@ def test_tanh_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_log_expr(): + expr = ast.LogExpr(ast.NumVal(2.0)) + + expected_code = """ + 0.0 && x < 1e-8) || (x > -1e-9 && x < 0.0) + return x * (1.0 - x * 0.5) + end + if x_abs < 0.375 + coeffs = [ + 0.10378693562743769800686267719098e+1, + -0.13364301504908918098766041553133e+0, + 0.19408249135520563357926199374750e-1, + -0.30107551127535777690376537776592e-2, + 0.48694614797154850090456366509137e-3, + -0.81054881893175356066809943008622e-4, + 0.13778847799559524782938251496059e-4, + -0.23802210894358970251369992914935e-5, + 0.41640416213865183476391859901989e-6, + -0.73595828378075994984266837031998e-7, + 0.13117611876241674949152294345011e-7, + -0.23546709317742425136696092330175e-8, + 0.42522773276034997775638052962567e-9, + -0.77190894134840796826108107493300e-10, + 0.14075746481359069909215356472191e-10, + -0.25769072058024680627537078627584e-11, + 0.47342406666294421849154395005938e-12, + -0.87249012674742641745301263292675e-13, + 0.16124614902740551465739833119115e-13, + -0.29875652015665773006710792416815e-14, + 0.55480701209082887983041321697279e-15, + -0.10324619158271569595141333961932e-15] + return x * (1.0 - x * chebyshev_broucke(x / 0.375, coeffs)) + end + return Math.log(1.0 + x) +end +def chebyshev_broucke(x, coeffs) + b2 = b1 = b0 = 0.0 + x2 = x * 2 + coeffs.reverse_each do |i| + b2 = b1 + b1 = b0 + b0 = x2 * b1 - b2 + i + end + (b0 - b2) * 0.5 +end +""" + + interpreter = RubyInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_visual_basic.py b/tests/interpreters/test_visual_basic.py index a6434003..c277a536 100644 --- a/tests/interpreters/test_visual_basic.py +++ b/tests/interpreters/test_visual_basic.py @@ -410,6 +410,108 @@ def test_tanh_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_log_expr(): + expr = ast.LogExpr(ast.NumVal(2.0)) + + expected_code = """ +Module Model +Function Score(ByRef inputVector() As Double) As Double + Score = Math.Log(2.0) +End Function +End Module +""" + + interpreter = VisualBasicInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_log1p_expr(): + expr = ast.Log1pExpr(ast.NumVal(2.0)) + + expected_code = """ +Module Model +Function ChebyshevBroucke(ByVal x As Double, _ + ByRef coeffs() As Double) As Double + Dim b2 as Double + Dim b1 as Double + Dim b0 as Double + Dim x2 as Double + b2 = 0.0 + b1 = 0.0 + b0 = 0.0 + x2 = x * 2 + Dim i as Integer + For i = UBound(coeffs) - 1 To 0 Step -1 + b2 = b1 + b1 = b0 + b0 = x2 * b1 - b2 + coeffs(i) + Next i + ChebyshevBroucke = (b0 - b2) * 0.5 +End Function +Function Log1p(ByVal x As Double) As Double + If x = 0.0 Then + Log1p = 0.0 + Exit Function + End If + If x = -1.0 Then + On Error Resume Next + Log1p = -1.0 / 0.0 + Exit Function + End If + If x < -1.0 Then + On Error Resume Next + Log1p = 0.0 / 0.0 + Exit Function + End If + Dim xAbs As Double + xAbs = Math.Abs(x) + If xAbs < 0.5 * 4.94065645841247e-324 Then + Log1p = x + Exit Function + End If + If (x > 0.0 AND x < 1e-8) OR (x > -1e-9 AND x < 0.0) Then + Log1p = x * (1.0 - x * 0.5) + Exit Function + End If + If xAbs < 0.375 Then + Dim coeffs(22) As Double + coeffs(0) = 0.10378693562743769800686267719098e+1 + coeffs(1) = -0.13364301504908918098766041553133e+0 + coeffs(2) = 0.19408249135520563357926199374750e-1 + coeffs(3) = -0.30107551127535777690376537776592e-2 + coeffs(4) = 0.48694614797154850090456366509137e-3 + coeffs(5) = -0.81054881893175356066809943008622e-4 + coeffs(6) = 0.13778847799559524782938251496059e-4 + coeffs(7) = -0.23802210894358970251369992914935e-5 + coeffs(8) = 0.41640416213865183476391859901989e-6 + coeffs(9) = -0.73595828378075994984266837031998e-7 + coeffs(10) = 0.13117611876241674949152294345011e-7 + coeffs(11) = -0.23546709317742425136696092330175e-8 + coeffs(12) = 0.42522773276034997775638052962567e-9 + coeffs(13) = -0.77190894134840796826108107493300e-10 + coeffs(14) = 0.14075746481359069909215356472191e-10 + coeffs(15) = -0.25769072058024680627537078627584e-11 + coeffs(16) = 0.47342406666294421849154395005938e-12 + coeffs(17) = -0.87249012674742641745301263292675e-13 + coeffs(18) = 0.16124614902740551465739833119115e-13 + coeffs(19) = -0.29875652015665773006710792416815e-14 + coeffs(20) = 0.55480701209082887983041321697279e-15 + coeffs(21) = -0.10324619158271569595141333961932e-15 + Log1p = x * (1.0 - x * ChebyshevBroucke(x / 0.375, coeffs)) + Exit Function + End If + Log1p = Math.log(1.0 + x) +End Function +Function Score(ByRef inputVector() As Double) As Double + Score = Log1p(2.0) +End Function +End Module +""" + + interpreter = VisualBasicInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/test_ast.py b/tests/test_ast.py index 2155697d..db3523ee 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -48,6 +48,8 @@ def test_count_exprs_exclude_list(): ast.VectorVal([ ast.AbsExpr(ast.NumVal(-2)), ast.ExpExpr(ast.NumVal(2)), + ast.LogExpr(ast.NumVal(2)), + ast.Log1pExpr(ast.NumVal(2)), ast.SqrtExpr(ast.NumVal(2)), ast.PowExpr(ast.NumVal(2), ast.NumVal(3)), ast.TanhExpr(ast.NumVal(1)), @@ -63,6 +65,8 @@ def test_count_exprs_exclude_list(): ast.NumVal(3), ast.NumVal(4), ast.NumVal(5), + ast.NumVal(6), + ast.NumVal(7), ast.FeatureRef(1) ])), ast.BinNumOpType.SUB), @@ -75,7 +79,7 @@ def test_count_exprs_exclude_list(): def test_count_all_exprs_types(): - assert ast.count_exprs(EXPR_WITH_ALL_EXPRS) == 31 + assert ast.count_exprs(EXPR_WITH_ALL_EXPRS) == 37 def test_exprs_equality(): diff --git a/tests/test_fallback_expressions.py b/tests/test_fallback_expressions.py index 111a4c76..1783ede7 100644 --- a/tests/test_fallback_expressions.py +++ b/tests/test_fallback_expressions.py @@ -79,3 +79,25 @@ def score(input): """ assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_log1p_fallback_expr(): + expr = ast.Log1pExpr(ast.NumVal(2.0)) + + interpreter = PythonInterpreter() + interpreter.log1p_function_name = NotImplemented + + expected_code = """ +import math +def score(input): + var1 = 2.0 + var2 = (1.0) + (var1) + var3 = (var2) - (1.0) + if (var3) == (0.0): + var0 = var1 + else: + var0 = ((var1) * (math.log(var2))) / (var3) + return var0 +""" + + assert_code_equal(interpreter.interpret(expr), expected_code)