added fallback expression for Tanh function #199

Merged
merged 10 commits into from
May 18, 2020

Conversation

Member

@StrikerRUS StrikerRUS commented Apr 30, 2020

Contributes to #196.

New languages are no longer required to implement or import a Tanh function; they can rely on our fallback expression instead (though this is NOT RECOMMENDED).

All current languages were tested with the test data fraction increased to 0.2: https://travis-ci.org/github/BayesWitnesses/m2cgen/builds/681243574.

Also, here is proof that the new fallback expression works fine:
comment out this line

tanh_function_name = "math.tanh"

and run

from math import tanh

import numpy as np
import m2cgen as m2c

interpreter = m2c.interpreters.PythonInterpreter()

for i in np.arange(-50, 50, 0.001):
    tanh_expr = m2c.ast.TanhExpr(m2c.ast.NumVal(i))
    exec(interpreter.interpret(tanh_expr))
    np.testing.assert_allclose(tanh(i), score([]), rtol=1e-10)
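
For readers without m2cgen installed, the same comparison can be sketched standalone. The function below is a hand-written transcription of the fallback as it appears in the generated Java and Haskell later in this comment (clamp at +/-44.0, otherwise 1 - 2 / (exp(2x) + 1)). The name `tanh_fallback` and the sample points are assumptions made for illustration, not part of the PR.

```python
import math


def tanh_fallback(x: float) -> float:
    """Hand-written copy of the generated fallback; not m2cgen code."""
    if x > 44.0:
        return 1.0
    if x < -44.0:
        return -1.0
    return 1.0 - 2.0 / (math.exp(2.0 * x) + 1.0)


# Illustrative sample points, including both clamp boundaries.
SAMPLES = (-50.0, -44.0, -1.0, 0.0, 0.5, 44.0, 50.0)

for x in SAMPLES:
    assert math.isclose(
        tanh_fallback(x), math.tanh(x), rel_tol=1e-10, abs_tol=1e-12
    )
```

The clamp keeps the argument of `exp` bounded, so very large inputs never reach the exponential at all.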

Below are examples of the code generated for Java and Haskell when the Tanh function is not provided:

from sklearn.datasets import load_boston
from sklearn.svm import SVR

import m2cgen as m2c

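# Note: load_boston was removed in scikit-learn 1.2, so this snippet needs an older release.
X, y = load_boston(return_X_y=True)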
X = X[:10, :4]
y = y[:10]

est = SVR(kernel='sigmoid').fit(X, y)

print(m2c.export_to_java(est))
public class Model {

    public static double score(double[] input) {
        double var0;
        double var1;
        var1 = ((0.010678114199077144) * (((((0.00632) * (input[0])) + ((18.0) * (input[1]))) + ((2.31) * (input[2]))) + ((0.0) * (input[3])))) + (0.0);
        if ((var1) > (44.0)) {
            var0 = 1.0;
        } else {
            if ((var1) < (-44.0)) {
                var0 = -1.0;
            } else {
                var0 = (1.0) - ((2.0) / ((Math.exp((2.0) * (var1))) + (1.0)));
            }
        }
        double var2;
        double var3;
        var3 = ((0.010678114199077144) * (((((0.02731) * (input[0])) + ((0.0) * (input[1]))) + ((7.07) * (input[2]))) + ((0.0) * (input[3])))) + (0.0);
        if ((var3) > (44.0)) {
            var2 = 1.0;
        } else {
            if ((var3) < (-44.0)) {
                var2 = -1.0;
            } else {
                var2 = (1.0) - ((2.0) / ((Math.exp((2.0) * (var3))) + (1.0)));
            }
        }
        double var4;
        double var5;
        var5 = ((0.010678114199077144) * (((((0.02729) * (input[0])) + ((0.0) * (input[1]))) + ((7.07) * (input[2]))) + ((0.0) * (input[3])))) + (0.0);
        if ((var5) > (44.0)) {
            var4 = 1.0;
        } else {
            if ((var5) < (-44.0)) {
                var4 = -1.0;
            } else {
                var4 = (1.0) - ((2.0) / ((Math.exp((2.0) * (var5))) + (1.0)));
            }
        }
        double var6;
        double var7;
        var7 = ((0.010678114199077144) * (((((0.03237) * (input[0])) + ((0.0) * (input[1]))) + ((2.18) * (input[2]))) + ((0.0) * (input[3])))) + (0.0);
        if ((var7) > (44.0)) {
            var6 = 1.0;
        } else {
            if ((var7) < (-44.0)) {
                var6 = -1.0;
            } else {
                var6 = (1.0) - ((2.0) / ((Math.exp((2.0) * (var7))) + (1.0)));
            }
        }
        double var8;
        double var9;
        var9 = ((0.010678114199077144) * (((((0.06905) * (input[0])) + ((0.0) * (input[1]))) + ((2.18) * (input[2]))) + ((0.0) * (input[3])))) + (0.0);
        if ((var9) > (44.0)) {
            var8 = 1.0;
        } else {
            if ((var9) < (-44.0)) {
                var8 = -1.0;
            } else {
                var8 = (1.0) - ((2.0) / ((Math.exp((2.0) * (var9))) + (1.0)));
            }
        }
        double var10;
        double var11;
        var11 = ((0.010678114199077144) * (((((0.02985) * (input[0])) + ((0.0) * (input[1]))) + ((2.18) * (input[2]))) + ((0.0) * (input[3])))) + (0.0);
        if ((var11) > (44.0)) {
            var10 = 1.0;
        } else {
            if ((var11) < (-44.0)) {
                var10 = -1.0;
            } else {
                var10 = (1.0) - ((2.0) / ((Math.exp((2.0) * (var11))) + (1.0)));
            }
        }
        double var12;
        double var13;
        var13 = ((0.010678114199077144) * (((((0.08829) * (input[0])) + ((12.5) * (input[1]))) + ((7.87) * (input[2]))) + ((0.0) * (input[3])))) + (0.0);
        if ((var13) > (44.0)) {
            var12 = 1.0;
        } else {
            if ((var13) < (-44.0)) {
                var12 = -1.0;
            } else {
                var12 = (1.0) - ((2.0) / ((Math.exp((2.0) * (var13))) + (1.0)));
            }
        }
        double var14;
        double var15;
        var15 = ((0.010678114199077144) * (((((0.14455) * (input[0])) + ((12.5) * (input[1]))) + ((7.87) * (input[2]))) + ((0.0) * (input[3])))) + (0.0);
        if ((var15) > (44.0)) {
            var14 = 1.0;
        } else {
            if ((var15) < (-44.0)) {
                var14 = -1.0;
            } else {
                var14 = (1.0) - ((2.0) / ((Math.exp((2.0) * (var15))) + (1.0)));
            }
        }
        double var16;
        double var17;
        var17 = ((0.010678114199077144) * (((((0.21124) * (input[0])) + ((12.5) * (input[1]))) + ((7.87) * (input[2]))) + ((0.0) * (input[3])))) + (0.0);
        if ((var17) > (44.0)) {
            var16 = 1.0;
        } else {
            if ((var17) < (-44.0)) {
                var16 = -1.0;
            } else {
                var16 = (1.0) - ((2.0) / ((Math.exp((2.0) * (var17))) + (1.0)));
            }
        }
        double var18;
        double var19;
        var19 = ((0.010678114199077144) * (((((0.17004) * (input[0])) + ((12.5) * (input[1]))) + ((7.87) * (input[2]))) + ((0.0) * (input[3])))) + (0.0);
        if ((var19) > (44.0)) {
            var18 = 1.0;
        } else {
            if ((var19) < (-44.0)) {
                var18 = -1.0;
            } else {
                var18 = (1.0) - ((2.0) / ((Math.exp((2.0) * (var19))) + (1.0)));
            }
        }
        return ((((((((((27.889501992985608) + ((var0) * (-1.0))) + ((var2) * (-1.0))) + ((var4) * (1.0))) + ((var6) * (1.0))) + ((var8) * (1.0))) + ((var10) * (1.0))) + ((var12) * (-1.0))) + ((var14) * (1.0))) + ((var16) * (-1.0))) + ((var18) * (-1.0));
    }
}

module Model where
score :: [Double] -> Double
score input =
    ((((((((((27.889501992985608) + ((func1) * (-1.0))) + ((func3) * (-1.0))) + ((func5) * (1.0))) + ((func7) * (1.0))) + ((func9) * (1.0))) + ((func11) * (1.0))) + ((func13) * (-1.0))) + ((func15) * (1.0))) + ((func17) * (-1.0))) + ((func19) * (-1.0))
    where
        func0 =
            ((0.010678114199077144) * (((((0.00632) * ((input) !! (0))) + ((18.0) * ((input) !! (1)))) + ((2.31) * ((input) !! (2)))) + ((0.0) * ((input) !! (3))))) + (0.0)
        func1 =
            if ((func0) > (44.0))
                then
                    1.0
                else
                    if ((func0) < (-44.0))
                        then
                            -1.0
                        else
                            (1.0) - ((2.0) / ((exp ((2.0) * (func0))) + (1.0)))
        func2 =
            ((0.010678114199077144) * (((((0.02731) * ((input) !! (0))) + ((0.0) * ((input) !! (1)))) + ((7.07) * ((input) !! (2)))) + ((0.0) * ((input) !! (3))))) + (0.0)
        func3 =
            if ((func2) > (44.0))
                then
                    1.0
                else
                    if ((func2) < (-44.0))
                        then
                            -1.0
                        else
                            (1.0) - ((2.0) / ((exp ((2.0) * (func2))) + (1.0)))
        func4 =
            ((0.010678114199077144) * (((((0.02729) * ((input) !! (0))) + ((0.0) * ((input) !! (1)))) + ((7.07) * ((input) !! (2)))) + ((0.0) * ((input) !! (3))))) + (0.0)
        func5 =
            if ((func4) > (44.0))
                then
                    1.0
                else
                    if ((func4) < (-44.0))
                        then
                            -1.0
                        else
                            (1.0) - ((2.0) / ((exp ((2.0) * (func4))) + (1.0)))
        func6 =
            ((0.010678114199077144) * (((((0.03237) * ((input) !! (0))) + ((0.0) * ((input) !! (1)))) + ((2.18) * ((input) !! (2)))) + ((0.0) * ((input) !! (3))))) + (0.0)
        func7 =
            if ((func6) > (44.0))
                then
                    1.0
                else
                    if ((func6) < (-44.0))
                        then
                            -1.0
                        else
                            (1.0) - ((2.0) / ((exp ((2.0) * (func6))) + (1.0)))
        func8 =
            ((0.010678114199077144) * (((((0.06905) * ((input) !! (0))) + ((0.0) * ((input) !! (1)))) + ((2.18) * ((input) !! (2)))) + ((0.0) * ((input) !! (3))))) + (0.0)
        func9 =
            if ((func8) > (44.0))
                then
                    1.0
                else
                    if ((func8) < (-44.0))
                        then
                            -1.0
                        else
                            (1.0) - ((2.0) / ((exp ((2.0) * (func8))) + (1.0)))
        func10 =
            ((0.010678114199077144) * (((((0.02985) * ((input) !! (0))) + ((0.0) * ((input) !! (1)))) + ((2.18) * ((input) !! (2)))) + ((0.0) * ((input) !! (3))))) + (0.0)
        func11 =
            if ((func10) > (44.0))
                then
                    1.0
                else
                    if ((func10) < (-44.0))
                        then
                            -1.0
                        else
                            (1.0) - ((2.0) / ((exp ((2.0) * (func10))) + (1.0)))
        func12 =
            ((0.010678114199077144) * (((((0.08829) * ((input) !! (0))) + ((12.5) * ((input) !! (1)))) + ((7.87) * ((input) !! (2)))) + ((0.0) * ((input) !! (3))))) + (0.0)
        func13 =
            if ((func12) > (44.0))
                then
                    1.0
                else
                    if ((func12) < (-44.0))
                        then
                            -1.0
                        else
                            (1.0) - ((2.0) / ((exp ((2.0) * (func12))) + (1.0)))
        func14 =
            ((0.010678114199077144) * (((((0.14455) * ((input) !! (0))) + ((12.5) * ((input) !! (1)))) + ((7.87) * ((input) !! (2)))) + ((0.0) * ((input) !! (3))))) + (0.0)
        func15 =
            if ((func14) > (44.0))
                then
                    1.0
                else
                    if ((func14) < (-44.0))
                        then
                            -1.0
                        else
                            (1.0) - ((2.0) / ((exp ((2.0) * (func14))) + (1.0)))
        func16 =
            ((0.010678114199077144) * (((((0.21124) * ((input) !! (0))) + ((12.5) * ((input) !! (1)))) + ((7.87) * ((input) !! (2)))) + ((0.0) * ((input) !! (3))))) + (0.0)
        func17 =
            if ((func16) > (44.0))
                then
                    1.0
                else
                    if ((func16) < (-44.0))
                        then
                            -1.0
                        else
                            (1.0) - ((2.0) / ((exp ((2.0) * (func16))) + (1.0)))
        func18 =
            ((0.010678114199077144) * (((((0.17004) * ((input) !! (0))) + ((12.5) * ((input) !! (1)))) + ((7.87) * ((input) !! (2)))) + ((0.0) * ((input) !! (3))))) + (0.0)
        func19 =
            if ((func18) > (44.0))
                then
                    1.0
                else
                    if ((func18) < (-44.0))
                        then
                            -1.0
                        else
                            (1.0) - ((2.0) / ((exp ((2.0) * (func18))) + (1.0)))
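
Each `var`/`func` pair in the generated code above evaluates one sigmoid-kernel term, tanh(gamma * <sv, x> + coef0), with the tanh call expanded into the fallback. For comparison, here is a minimal sketch of the first term written with a native `math.tanh`. The constants are read directly off the generated code; the names `kernel_term`, `GAMMA`, `COEF0` and `FIRST_SUPPORT_VECTOR` are hypothetical.

```python
import math

# Constants copied from the first term of the generated code above.
GAMMA = 0.010678114199077144
COEF0 = 0.0
FIRST_SUPPORT_VECTOR = [0.00632, 18.0, 2.31, 0.0]


def kernel_term(x, sv=FIRST_SUPPORT_VECTOR, gamma=GAMMA, coef0=COEF0):
    """Sigmoid-kernel term tanh(gamma * <sv, x> + coef0) using native tanh."""
    dot = sum(s * xi for s, xi in zip(sv, x))
    return math.tanh(gamma * dot + coef0)


# The same term computed with the expanded expression from the generated code.
x = FIRST_SUPPORT_VECTOR
z = GAMMA * sum(s * xi for s, xi in zip(FIRST_SUPPORT_VECTOR, x)) + COEF0
expanded = 1.0 - 2.0 / (math.exp(2.0 * z) + 1.0)
assert math.isclose(kernel_term(x), expanded, rel_tol=1e-10)
```

In the generated model, each such term is then multiplied by +1.0 or -1.0 and added to the constant 27.889501992985608 in the final return expression.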

// https://github.com/golang/go/blob/master/src/math/tanh.go
double z;
z = x.abs();
if (z > 44.0148459655565271479942397125) {

@@ -128,28 +129,33 @@ def interpret_vector_val(self, expr, **kwargs):
        return self._cg.vector_init(nested)

    def interpret_exp_expr(self, expr, **kwargs):
        assert self.exponent_function_name, "Exponent function is not provided"
Member Author

The previous asserts didn't work. NotImplemented is truthy here, so an assert like this one passes silently:

test = NotImplemented
assert test, "Alert!!!"
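
Because the sentinel in the snippet above is not falsy, asserting on it cannot detect a missing override. One way to make the check effective is an explicit identity test, sketched below. The class and method bodies are hypothetical stand-ins for illustration and are not presented as the code merged in this PR.

```python
class BaseInterpreter:
    # Hypothetical stand-in: subclasses are expected to override this name.
    exponent_function_name = NotImplemented

    def interpret_exp_expr(self, arg: str) -> str:
        # Identity check against the sentinel instead of a truthiness check.
        if self.exponent_function_name is NotImplemented:
            raise NotImplementedError("Exponent function is not provided")
        return f"{self.exponent_function_name}({arg})"


class PythonLikeInterpreter(BaseInterpreter):
    # Hypothetical subclass that provides a native exponent function.
    exponent_function_name = "math.exp"


print(PythonLikeInterpreter().interpret_exp_expr("x"))
```

An explicit `raise` also keeps working when Python runs with `-O`, which strips `assert` statements.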

Member

Great catch 👍

@coveralls

Coverage Status

Coverage decreased (-0.5%) to 95.346% when pulling 2bd524a on exprs into f1caa20 on master.

@coveralls

coveralls commented Apr 30, 2020

Coverage Status

Coverage decreased (-0.07%) to 95.459% when pulling 22ba2c7 on exprs into 3db1cc2 on master.

@StrikerRUS
Member Author

Hmmm, again that error, but now for Python

=================================== FAILURES ===================================
_ test_e2e[xgboost_XGBClassifier - python - train_model_classification_binary2] _
estimator = XGBClassifier(base_score=0.6, booster='gblinear', colsample_bylevel=None,
              colsample_bynode=None, colsamp...ambda=0, scale_pos_weight=1, subsample=None,
              tree_method=None, validate_parameters=False, verbosity=None)
executor_cls = <class 'tests.e2e.executors.python.PythonExecutor'>

...

expected=[0.04805166 0.95194834], actual=[0.04805117682396598, 0.951948823176034]
expected=[0.06518847 0.93481153], actual=[0.06518802156571324, 0.9348119784342868]
expected=[0.12917638 0.8708236 ], actual=[0.12917560920779636, 0.8708243907922036]
expected=[0.07704622 0.9229538 ], actual=[0.07704532874494041, 0.9229546712550596]
expected=[0.8145962  0.18540382], actual=[0.8145932871925088, 0.18540671280749127]
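
To put the mismatch in the log above into numbers, here is a small sketch that copies the five expected/actual pairs from the output and reports the largest absolute and relative differences. Only the values come from the log; the variable names are illustrative, and the tolerance used by the test suite is not shown here, so none is assumed.

```python
# Values copied verbatim from the failure log above.
expected = [
    [0.04805166, 0.95194834],
    [0.06518847, 0.93481153],
    [0.12917638, 0.8708236],
    [0.07704622, 0.9229538],
    [0.8145962, 0.18540382],
]
actual = [
    [0.04805117682396598, 0.951948823176034],
    [0.06518802156571324, 0.9348119784342868],
    [0.12917560920779636, 0.8708243907922036],
    [0.07704532874494041, 0.9229546712550596],
    [0.8145932871925088, 0.18540671280749127],
]

# Flatten the rows into (expected, actual) value pairs.
pairs = [
    (e, a)
    for e_row, a_row in zip(expected, actual)
    for e, a in zip(e_row, a_row)
]
max_abs = max(abs(e - a) for e, a in pairs)
max_rel = max(abs(e - a) / abs(e) for e, a in pairs)

print(f"max abs diff: {max_abs:.3e}, max rel diff: {max_rel:.3e}")
```

The differences sit at the 1e-6 level in absolute terms, which is far larger than double-precision rounding noise.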

@StrikerRUS StrikerRUS changed the title added fallback expression for Tanh function [WIP] added fallback expression for Tanh function Apr 30, 2020
@StrikerRUS
Member Author

Marking WIP as I'm going to add a test to prevent the code coverage decrease.

@StrikerRUS StrikerRUS changed the title [WIP] added fallback expression for Tanh function added fallback expression for Tanh function May 1, 2020
@izeigerman
Member

Wow, this is awesome! I'll take a look soon 👍

Member

@izeigerman izeigerman left a comment

Fantastic job 👍 Opened a few topics for discussion.

@@ -0,0 +1,43 @@
from m2cgen import ast
from m2cgen.assemblers import utils

Member

@izeigerman izeigerman May 7, 2020

I think it's worth leaving a comment or a pydoc explaining when these fallbacks are being applied. Something like:

This module provides implementations of a variety of functions expressed in the library's AST.
These AST-based implementations are used as fallbacks when the target language lacks native support for the respective functions.

Member Author

My fault! Great docstring, will add it.
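
To illustrate what "functions expressed in the library's AST" can mean, here is a minimal sketch that builds the tanh fallback as an expression tree and evaluates it. The node classes (`Num`, `Var`, `BinOp`, `Exp`, `If`) and the `evaluate` helper are hypothetical simplifications invented for this example; they are not m2cgen's actual AST or interpreter API.

```python
import math
import operator
from dataclasses import dataclass


@dataclass(frozen=True)
class Num:
    value: float


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class BinOp:
    op: str
    left: object
    right: object


@dataclass(frozen=True)
class Exp:
    arg: object


@dataclass(frozen=True)
class If:
    test: BinOp
    body: object
    orelse: object


OPS = {
    "+": operator.add, "-": operator.sub, "*": operator.mul,
    "/": operator.truediv, ">": operator.gt, "<": operator.lt,
}


def tanh_fallback(arg):
    """Build tanh(arg) from comparison, arithmetic and exp nodes only."""
    exp_2x = Exp(BinOp("*", Num(2.0), arg))
    core = BinOp("-", Num(1.0), BinOp("/", Num(2.0), BinOp("+", exp_2x, Num(1.0))))
    return If(BinOp(">", arg, Num(44.0)), Num(1.0),
              If(BinOp("<", arg, Num(-44.0)), Num(-1.0), core))


def evaluate(node, env):
    """Tiny reference evaluator standing in for a target-language generator."""
    if isinstance(node, Num):
        return node.value
    if isinstance(node, Var):
        return env[node.name]
    if isinstance(node, Exp):
        return math.exp(evaluate(node.arg, env))
    if isinstance(node, BinOp):
        return OPS[node.op](evaluate(node.left, env), evaluate(node.right, env))
    if isinstance(node, If):
        # Only the selected branch is evaluated, so exp() never sees huge inputs.
        branch = node.body if evaluate(node.test, env) else node.orelse
        return evaluate(branch, env)
    raise TypeError(f"unknown node: {node!r}")


expr = tanh_fallback(Var("x"))
assert math.isclose(evaluate(expr, {"x": 0.5}), math.tanh(0.5), rel_tol=1e-12)
```

Note that the argument node appears three times in the tree, which is presumably why the real `tanh` helper marks its input for reuse.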



def tanh(expr):
    expr.to_reuse = True
Member

I'm a little concerned that here we implicitly alter the state of the expression passed as an input argument. I worry that if we adopt this practice going forward, it may lead to unexpected behavior in the future. What do you say if instead we introduce some kind of identity expression, e.g. IdExpr? It could transparently wrap the original expression but have its reuse flag set to true. What do you think of this concern in general?

Member Author

@StrikerRUS StrikerRUS May 8, 2020

What do you say if instead we introduce some kind of identity expression, e.g. IdExpr?

Are you reading my thoughts?! 😄
The first thing that came to my mind was to use some dummy wrapper for the passed expression. I searched ast.py for one but unfortunately didn't find any.

I think it's better to add IdExpr in a separate PR to not overcomplicate this one. I'll prepare it.
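
The identity-wrapper idea discussed in this thread might look roughly like the sketch below. Both classes are hypothetical stand-ins written for this example; `IdExpr` did not exist in the library at the time of this discussion, and its eventual shape may differ.

```python
from dataclasses import dataclass


@dataclass
class NumVal:
    """Hypothetical leaf node; stands in for a real AST expression."""
    value: float
    to_reuse: bool = False


@dataclass
class IdExpr:
    """Hypothetical identity wrapper: same meaning as `expr`, own reuse flag."""
    expr: object
    to_reuse: bool = False


def reuse_in_place(expr):
    # Mutates the caller's object: the side effect is visible everywhere
    # the same node is referenced.
    expr.to_reuse = True
    return expr


def reuse_wrapped(expr):
    # Leaves the caller's object untouched and carries the flag on the wrapper.
    return IdExpr(expr, to_reuse=True)


shared = NumVal(1.5)
wrapped = reuse_wrapped(shared)
assert shared.to_reuse is False
assert wrapped.to_reuse is True and wrapped.expr is shared
```

With the wrapper, a caller that still holds `shared` sees no change, which is the property the review comment asks for.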

def eq(l, r):
    return ast.CompExpr(l, r, ast.CompOpType.EQ)


def ne(l, r):
Member

Is this function used? I couldn't find where

Member Author

Nope :-|
I added all possible comparison functions for completeness. Do you think the unused functions should be removed?

Member

In general I'd argue that we should introduce code only if it's being used, even when unused code seems to contribute to a perception of completeness. It is still unused code. I'm also guilty of doing this in the past, btw :D I suggest we introduce these functions once we have an application for them.

Member Author

I suggest we introduce these functions once we have an application for them.

OK, no problem!

    return ast.CompExpr(l, r, ast.CompOpType.GT)


def gte(l, r):
Member

Ditto

@StrikerRUS StrikerRUS mentioned this pull request May 8, 2020
Member

@izeigerman izeigerman left a comment

Awesome, thanks 👍
