diff --git a/MANIFEST.in b/MANIFEST.in index 6a9fa8a4..06b0e1d3 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -4,4 +4,5 @@ recursive-include m2cgen linear_algebra.* recursive-include m2cgen log1p.* recursive-include m2cgen tanh.* recursive-include m2cgen atan.* +recursive-include m2cgen softmax.* global-exclude *.py[cod] diff --git a/m2cgen/assemblers/boosting.py b/m2cgen/assemblers/boosting.py index 99be8267..0469981e 100644 --- a/m2cgen/assemblers/boosting.py +++ b/m2cgen/assemblers/boosting.py @@ -72,8 +72,7 @@ def _assemble_multi_class_output(self, estimator_params): for i, e in enumerate(splits) ] - proba_exprs = self._multi_class_convert_output(exprs) - return ast.VectorVal(proba_exprs) + return self._multi_class_convert_output(exprs) def _assemble_bin_class_output(self, estimator_params): # Base score is calculated based on @@ -97,7 +96,7 @@ def _final_transform(self, ast_to_transform): return ast_to_transform def _multi_class_convert_output(self, exprs): - return fallback_expressions.softmax(exprs) + return ast.SoftmaxExpr(exprs) def _bin_class_convert_output(self, expr, to_reuse=True): return fallback_expressions.sigmoid(expr, to_reuse=to_reuse) @@ -250,8 +249,9 @@ def _multi_class_convert_output(self, exprs): return supported_objectives[self.objective_name](exprs) def _multi_class_sigmoid_transform(self, exprs): - return [self._bin_class_sigmoid_transform(expr, to_reuse=False) - for expr in exprs] + return ast.VectorVal([ + self._bin_class_sigmoid_transform(expr, to_reuse=False) + for expr in exprs]) def _bin_class_convert_output(self, expr, to_reuse=True): supported_objectives = { diff --git a/m2cgen/assemblers/fallback_expressions.py b/m2cgen/assemblers/fallback_expressions.py index e42329ed..56230484 100644 --- a/m2cgen/assemblers/fallback_expressions.py +++ b/m2cgen/assemblers/fallback_expressions.py @@ -190,7 +190,7 @@ def softmax(exprs): exp_exprs = [ast.ExpExpr(e, to_reuse=True) for e in exprs] exp_sum_expr = utils.apply_op_to_expressions( ast.BinNumOpType.ADD, *exp_exprs, to_reuse=True) - return [ + return ast.VectorVal([ ast.BinNumExpr(e, exp_sum_expr, ast.BinNumOpType.DIV) for e in exp_exprs - ] + ]) diff --git a/m2cgen/ast.py b/m2cgen/ast.py index 4c167677..87f090e8 100644 --- a/m2cgen/ast.py +++ b/m2cgen/ast.py @@ -266,6 +266,29 @@ def __hash__(self): return hash(tuple(self.exprs)) +class SoftmaxExpr(VectorExpr): + + def __init__(self, exprs, to_reuse=False): + assert all(e.output_size == 1 for e in exprs), ( + "All expressions for SoftmaxExpr must be scalar") + + self.exprs = exprs + self.to_reuse = to_reuse + self.output_size = len(exprs) + + def __str__(self): + args = ",".join([str(e) for e in self.exprs]) + return f"SoftmaxExpr({args},to_reuse={self.to_reuse})" + + def __eq__(self, other): + return (type(other) is SoftmaxExpr and + self.output_size == other.output_size and + all(i == j for i, j in zip(self.exprs, other.exprs))) + + def __hash__(self): + return hash(tuple(self.exprs)) + + class BinVectorExpr(VectorExpr, BinExpr): def __init__(self, left, right, op): @@ -384,9 +407,9 @@ def __hash__(self): NESTED_EXPRS_MAPPINGS = [ ((BinExpr, CompExpr), lambda e: [e.left, e.right]), - (PowExpr, lambda e: [e.base_expr, e.exp_expr]), - (VectorVal, lambda e: e.exprs), - (IfExpr, lambda e: [e.test, e.body, e.orelse]), + ((PowExpr), lambda e: [e.base_expr, e.exp_expr]), + ((VectorVal, SoftmaxExpr), lambda e: e.exprs), + ((IfExpr), lambda e: [e.test, e.body, e.orelse]), ((AbsExpr, AtanExpr, ExpExpr, IdExpr, LogExpr, Log1pExpr, SqrtExpr, TanhExpr), lambda e: [e.expr]), diff --git a/m2cgen/interpreters/c/interpreter.py b/m2cgen/interpreters/c/interpreter.py index a8e55c8a..9f5aa00a 100644 --- a/m2cgen/interpreters/c/interpreter.py +++ b/m2cgen/interpreters/c/interpreter.py @@ -23,9 +23,12 @@ class CInterpreter(ImperativeToCodeInterpreter, logarithm_function_name = "log" log1p_function_name = "log1p" power_function_name = "pow" + softmax_function_name = "softmax" sqrt_function_name = "sqrt" tanh_function_name = "tanh" + with_softmax_expr = False + def __init__(self, indent=4, function_name="score", *args, **kwargs): self.function_name = function_name @@ -61,6 +64,11 @@ def interpret(self, expr): os.path.dirname(__file__), "linear_algebra.c") self._cg.prepend_code_lines(utils.get_file_content(filename)) + if self.with_softmax_expr: + filename = os.path.join( + os.path.dirname(__file__), "softmax.c") + self._cg.prepend_code_lines(utils.get_file_content(filename)) + if self.with_vectors: self._cg.add_dependency("") @@ -99,3 +107,18 @@ def interpret_bin_vector_num_expr(self, expr, **kwargs): self._cg.add_code_line(f"{func_inv};") return var_name + + # Do the same things for softmax as for linear algebra. + def interpret_softmax_expr(self, expr, **kwargs): + self.with_vectors = True + self.with_softmax_expr = True + + var_name = self._cg.add_var_declaration(expr.output_size) + nested = [self._do_interpret(expr, **kwargs) for expr in expr.exprs] + func_inv = self._cg.function_invocation( + self.softmax_function_name, + self._cg.vector_init(nested), + expr.output_size, + var_name) + self._cg.add_code_line(f"{func_inv};") + return var_name diff --git a/m2cgen/interpreters/c_sharp/interpreter.py b/m2cgen/interpreters/c_sharp/interpreter.py index 7a8b721c..17fc3363 100644 --- a/m2cgen/interpreters/c_sharp/interpreter.py +++ b/m2cgen/interpreters/c_sharp/interpreter.py @@ -24,10 +24,12 @@ class CSharpInterpreter(ImperativeToCodeInterpreter, logarithm_function_name = "Log" log1p_function_name = "Log1p" power_function_name = "Pow" + softmax_function_name = "Softmax" sqrt_function_name = "Sqrt" tanh_function_name = "Tanh" with_log1p_expr = False + with_softmax_expr = False def __init__(self, namespace="ML", class_name="Model", indent=4, function_name="Score", *args, **kwargs): @@ -66,6 +68,11 @@ def interpret(self, expr): os.path.dirname(__file__), "log1p.cs") self._cg.add_code_lines(utils.get_file_content(filename)) + if self.with_softmax_expr: + filename = os.path.join( + os.path.dirname(__file__), "softmax.cs") + self._cg.add_code_lines(utils.get_file_content(filename)) + if self.with_math_module: self._cg.add_dependency("System.Math") @@ -74,3 +81,7 @@ def interpret(self, expr): def interpret_log1p_expr(self, expr, **kwargs): self.with_log1p_expr = True return super().interpret_log1p_expr(expr, **kwargs) + + def interpret_softmax_expr(self, expr, **kwargs): + self.with_softmax_expr = True + return super().interpret_softmax_expr(expr, **kwargs) diff --git a/m2cgen/interpreters/dart/interpreter.py b/m2cgen/interpreters/dart/interpreter.py index 1ed22067..e72d8f11 100644 --- a/m2cgen/interpreters/dart/interpreter.py +++ b/m2cgen/interpreters/dart/interpreter.py @@ -27,10 +27,12 @@ class DartInterpreter(ImperativeToCodeInterpreter, logarithm_function_name = "log" log1p_function_name = "log1p" power_function_name = "pow" + softmax_function_name = "softmax" sqrt_function_name = "sqrt" tanh_function_name = "tanh" with_log1p_expr = False + with_softmax_expr = False with_tanh_expr = False def __init__(self, indent=4, function_name="score", *args, **kwargs): @@ -63,6 +65,11 @@ def interpret(self, expr): os.path.dirname(__file__), "log1p.dart") self._cg.add_code_lines(utils.get_file_content(filename)) + if self.with_softmax_expr: + filename = os.path.join( + os.path.dirname(__file__), "softmax.dart") + self._cg.add_code_lines(utils.get_file_content(filename)) + if self.with_tanh_expr: filename = os.path.join( os.path.dirname(__file__), "tanh.dart") @@ -86,3 +93,7 @@ def interpret_log1p_expr(self, expr, **kwargs): def interpret_tanh_expr(self, expr, **kwargs): self.with_tanh_expr = True return super().interpret_tanh_expr(expr, **kwargs) + + def interpret_softmax_expr(self, expr, **kwargs): + self.with_softmax_expr = True + return super().interpret_softmax_expr(expr, **kwargs) diff --git a/m2cgen/interpreters/f_sharp/interpreter.py b/m2cgen/interpreters/f_sharp/interpreter.py index 89300078..e9250d7b 100644 --- a/m2cgen/interpreters/f_sharp/interpreter.py +++ b/m2cgen/interpreters/f_sharp/interpreter.py @@ -30,10 +30,12 @@ class FSharpInterpreter(FunctionalToCodeInterpreter, exponent_function_name = "exp" logarithm_function_name = "log" log1p_function_name = "log1p" + softmax_function_name = "softmax" sqrt_function_name = "sqrt" tanh_function_name = "tanh" with_log1p_expr = False + with_softmax_expr = False def __init__(self, indent=4, function_name="score", *args, **kwargs): self.indent = indent @@ -63,6 +65,11 @@ def interpret(self, expr): os.path.dirname(__file__), "log1p.fs") self._cg.prepend_code_lines(utils.get_file_content(filename)) + if self.with_softmax_expr: + filename = os.path.join( + os.path.dirname(__file__), "softmax.fs") + self._cg.prepend_code_lines(utils.get_file_content(filename)) + return self._cg.finalize_and_get_generated_code() def create_code_generator(self): @@ -78,6 +85,10 @@ def interpret_log1p_expr(self, expr, **kwargs): self.with_log1p_expr = True return super().interpret_log1p_expr(expr, **kwargs) + def interpret_softmax_expr(self, expr, **kwargs): + self.with_softmax_expr = True + return super().interpret_softmax_expr(expr, **kwargs) + def _dump_cache(self): if self._cached_expr_results: for func_name, expr_result in self._cached_expr_results.values(): diff --git a/m2cgen/interpreters/go/interpreter.py b/m2cgen/interpreters/go/interpreter.py index a325d4db..21663ee8 100644 --- a/m2cgen/interpreters/go/interpreter.py +++ b/m2cgen/interpreters/go/interpreter.py @@ -22,9 +22,12 @@ class GoInterpreter(ImperativeToCodeInterpreter, logarithm_function_name = "math.Log" log1p_function_name = "math.Log1p" power_function_name = "math.Pow" + softmax_function_name = "softmax" sqrt_function_name = "math.Sqrt" tanh_function_name = "math.Tanh" + with_softmax_expr = False + def __init__(self, indent=4, function_name="score", *args, **kwargs): self.function_name = function_name @@ -51,7 +54,16 @@ def interpret(self, expr): os.path.dirname(__file__), "linear_algebra.go") self._cg.prepend_code_lines(utils.get_file_content(filename)) + if self.with_softmax_expr: + filename = os.path.join( + os.path.dirname(__file__), "softmax.go") + self._cg.prepend_code_lines(utils.get_file_content(filename)) + if self.with_math_module: self._cg.add_dependency("math") return self._cg.finalize_and_get_generated_code() + + def interpret_softmax_expr(self, expr, **kwargs): + self.with_softmax_expr = True + return super().interpret_softmax_expr(expr, **kwargs) diff --git a/m2cgen/interpreters/haskell/interpreter.py b/m2cgen/interpreters/haskell/interpreter.py index 6a435418..6c6abee4 100644 --- a/m2cgen/interpreters/haskell/interpreter.py +++ b/m2cgen/interpreters/haskell/interpreter.py @@ -21,10 +21,12 @@ class HaskellInterpreter(FunctionalToCodeInterpreter, exponent_function_name = "exp" logarithm_function_name = "log" log1p_function_name = "log1p" + softmax_function_name = "softmax" sqrt_function_name = "sqrt" tanh_function_name = "tanh" with_log1p_expr = False + with_softmax_expr = False def __init__(self, module_name="Model", indent=4, function_name="score", *args, **kwargs): @@ -59,6 +61,11 @@ def interpret(self, expr): os.path.dirname(__file__), "log1p.hs") self._cg.prepend_code_lines(utils.get_file_content(filename)) + if self.with_softmax_expr: + filename = os.path.join( + os.path.dirname(__file__), "softmax.hs") + self._cg.prepend_code_lines(utils.get_file_content(filename)) + self._cg.prepend_code_line(self._cg.tpl_module_definition( module_name=self.module_name)) @@ -77,6 +84,10 @@ def interpret_log1p_expr(self, expr, **kwargs): self.with_log1p_expr = True return super().interpret_log1p_expr(expr, **kwargs) + def interpret_softmax_expr(self, expr, **kwargs): + self.with_softmax_expr = True + return super().interpret_softmax_expr(expr, **kwargs) + def _dump_cache(self): if self._cached_expr_results: self._cg.add_code_line("where") diff --git a/m2cgen/interpreters/interpreter.py b/m2cgen/interpreters/interpreter.py index fd684673..052263ba 100644 --- a/m2cgen/interpreters/interpreter.py +++ b/m2cgen/interpreters/interpreter.py @@ -87,6 +87,7 @@ class ToCodeInterpreter(BaseToCodeInterpreter): logarithm_function_name = NotImplemented log1p_function_name = NotImplemented power_function_name = NotImplemented + softmax_function_name = NotImplemented sqrt_function_name = NotImplemented tanh_function_name = NotImplemented @@ -169,6 +170,17 @@ def interpret_log1p_expr(self, expr, **kwargs): return self._cg.function_invocation( self.log1p_function_name, nested_result) + def interpret_softmax_expr(self, expr, **kwargs): + if self.softmax_function_name is NotImplemented: + return self._do_interpret( + fallback_expressions.softmax(expr.exprs), + **kwargs) + self.with_vectors = True + self.with_math_module = True + nested = [self._do_interpret(expr, **kwargs) for expr in expr.exprs] + return self._cg.function_invocation( + self.softmax_function_name, self._cg.vector_init(nested)) + def interpret_sqrt_expr(self, expr, **kwargs): if self.sqrt_function_name is NotImplemented: return self._do_interpret( diff --git a/m2cgen/interpreters/java/interpreter.py b/m2cgen/interpreters/java/interpreter.py index c67532dd..45762949 100644 --- a/m2cgen/interpreters/java/interpreter.py +++ b/m2cgen/interpreters/java/interpreter.py @@ -30,9 +30,12 @@ class JavaInterpreter(ImperativeToCodeInterpreter, logarithm_function_name = "Math.log" log1p_function_name = "Math.log1p" power_function_name = "Math.pow" + softmax_function_name = "softmax" sqrt_function_name = "Math.sqrt" tanh_function_name = "Math.tanh" + with_softmax_expr = False + def __init__(self, package_name=None, class_name="Model", indent=4, function_name="score", *args, **kwargs): self.package_name = package_name @@ -64,9 +67,18 @@ def interpret(self, expr): os.path.dirname(__file__), "linear_algebra.java") top_cg.add_code_lines(utils.get_file_content(filename)) + if self.with_softmax_expr: + filename = os.path.join( + os.path.dirname(__file__), "softmax.java") + top_cg.add_code_lines(utils.get_file_content(filename)) + return top_cg.finalize_and_get_generated_code() # Required by SubroutinesMixin to create new code generator for # each subroutine. def create_code_generator(self): return JavaCodeGenerator(indent=self.indent) + + def interpret_softmax_expr(self, expr, **kwargs): + self.with_softmax_expr = True + return super().interpret_softmax_expr(expr, **kwargs) diff --git a/m2cgen/interpreters/javascript/interpreter.py b/m2cgen/interpreters/javascript/interpreter.py index 4ac4afcd..359c68c1 100644 --- a/m2cgen/interpreters/javascript/interpreter.py +++ b/m2cgen/interpreters/javascript/interpreter.py @@ -25,9 +25,12 @@ class JavascriptInterpreter(ImperativeToCodeInterpreter, logarithm_function_name = "Math.log" log1p_function_name = "Math.log1p" power_function_name = "Math.pow" + softmax_function_name = "softmax" sqrt_function_name = "Math.sqrt" tanh_function_name = "Math.tanh" + with_softmax_expr = False + def __init__(self, indent=4, function_name="score", *args, **kwargs): self.indent = indent @@ -53,4 +56,13 @@ def interpret(self, expr): os.path.dirname(__file__), "linear_algebra.js") self._cg.add_code_lines(utils.get_file_content(filename)) + if self.with_softmax_expr: + filename = os.path.join( + os.path.dirname(__file__), "softmax.js") + self._cg.add_code_lines(utils.get_file_content(filename)) + return self._cg.finalize_and_get_generated_code() + + def interpret_softmax_expr(self, expr, **kwargs): + self.with_softmax_expr = True + return super().interpret_softmax_expr(expr, **kwargs) diff --git a/m2cgen/interpreters/php/interpreter.py b/m2cgen/interpreters/php/interpreter.py index d25fa5a6..3957827c 100644 --- a/m2cgen/interpreters/php/interpreter.py +++ b/m2cgen/interpreters/php/interpreter.py @@ -23,9 +23,12 @@ class PhpInterpreter(ImperativeToCodeInterpreter, logarithm_function_name = "log" log1p_function_name = "log1p" power_function_name = "pow" + softmax_function_name = "softmax" sqrt_function_name = "sqrt" tanh_function_name = "tanh" + with_softmax_expr = False + def __init__(self, indent=4, function_name="score", *args, **kwargs): self.function_name = function_name @@ -47,6 +50,15 @@ def interpret(self, expr): os.path.dirname(__file__), "linear_algebra.php") self._cg.prepend_code_lines(utils.get_file_content(filename)) + if self.with_softmax_expr: + filename = os.path.join( + os.path.dirname(__file__), "softmax.php") + self._cg.prepend_code_lines(utils.get_file_content(filename)) + self._cg.prepend_code_line(" +void softmax(double *x, int size, double *result) { + double max = x[0]; + for (int i = 1; i < size; ++i) { + if (x[i] > max) + max = x[i]; + } + double sum = 0.0; + for (int i = 0; i < size; ++i) { + result[i] = exp(x[i] - max); + sum += result[i]; + } + for (int i = 0; i < size; ++i) + result[i] /= sum; +} +void score(double * input, double * output) { + double var0[2]; + softmax((double[]){2.0, 3.0}, 2, var0); + memcpy(output, var0, 2 * sizeof(double)); +}""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_c_sharp.py b/tests/interpreters/test_c_sharp.py index 77873dac..b89672c6 100644 --- a/tests/interpreters/test_c_sharp.py +++ b/tests/interpreters/test_c_sharp.py @@ -489,6 +489,41 @@ def test_atan_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_softmax_expr(): + expr = ast.SoftmaxExpr([ast.NumVal(2.0), ast.NumVal(3.0)]) + + expected_code = """ +using static System.Math; +namespace ML { + public static class Model { + public static double[] Score(double[] input) { + return Softmax(new double[2] {2.0, 3.0}); + } + private static double[] Softmax(double[] x) { + int size = x.Length; + double[] result = new double[size]; + double max = x[0]; + for (int i = 1; i < size; ++i) { + if (x[i] > max) + max = x[i]; + } + double sum = 0.0; + for (int i = 0; i < size; ++i) { + result[i] = Exp(x[i] - max); + sum += result[i]; + } + for (int i = 0; i < size; ++i) + result[i] /= sum; + return result; + } + } +} +""" + + interpreter = CSharpInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_dart.py b/tests/interpreters/test_dart.py index f5643fbe..721f8179 100644 --- a/tests/interpreters/test_dart.py +++ b/tests/interpreters/test_dart.py @@ -566,6 +566,32 @@ def test_atan_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_softmax_expr(): + expr = ast.SoftmaxExpr([ast.NumVal(2.0), ast.NumVal(3.0)]) + + expected_code = """ +import 'dart:math'; +List score(List input) { + return softmax([2.0, 3.0]); +} +List softmax(List x) { + int size = x.length; + List result = new List(size); + double maxElem = x.reduce(max); + double sum = 0.0; + for (int i = 0; i < size; ++i) { + result[i] = exp(x[i] - maxElem); + sum += result[i]; + } + for (int i = 0; i < size; ++i) + result[i] /= sum; + return result; +} +""" + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_f_sharp.py b/tests/interpreters/test_f_sharp.py index b8f2c595..b9fca010 100644 --- a/tests/interpreters/test_f_sharp.py +++ b/tests/interpreters/test_f_sharp.py @@ -455,6 +455,23 @@ def test_atan_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_softmax_expr(): + expr = ast.SoftmaxExpr([ast.NumVal(2.0), ast.NumVal(3.0)]) + + expected_code = """ +let private softmax x = + let maxElem = List.reduce max x + let exps = List.map (fun i -> exp (i - maxElem)) x + let sumExps = List.sum exps + List.map (fun i -> i / sumExps) exps +let score (input : double list) = + softmax ([2.0; 3.0]) +""" + + interpreter = FSharpInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_go.py b/tests/interpreters/test_go.py index c5ad342e..e2281326 100644 --- a/tests/interpreters/test_go.py +++ b/tests/interpreters/test_go.py @@ -327,6 +327,39 @@ def test_atan_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_softmax_expr(): + expr = ast.SoftmaxExpr([ast.NumVal(2.0), ast.NumVal(3.0)]) + + interpreter = interpreters.GoInterpreter() + + expected_code = """ +import "math" +func softmax(x []float64) []float64 { + size := len(x) + result := make([]float64, size) + max := x[0] + for _, v := range x { + if (v > max) { + max = v + } + } + sum := 0.0 + for i := 0; i < size; i++ { + result[i] = math.Exp(x[i] - max) + sum += result[i] + } + for i := 0; i < size; i++ { + result[i] /= sum + } + return result +} +func score(input []float64) []float64 { + return softmax([]float64{2.0, 3.0}) +}""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_haskell.py b/tests/interpreters/test_haskell.py index 490bb5d6..b8013cd5 100644 --- a/tests/interpreters/test_haskell.py +++ b/tests/interpreters/test_haskell.py @@ -361,6 +361,27 @@ def test_atan_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_softmax_expr(): + expr = ast.SoftmaxExpr([ast.NumVal(2.0), ast.NumVal(3.0)]) + + expected_code = r""" +module Model where +softmax :: [Double] -> [Double] +softmax x = + let + m = maximum x + exps = map (\i -> exp (i - m)) x + sumExps = sum exps + in map (\i -> i / sumExps) exps +score :: [Double] -> [Double] +score input = + softmax ([2.0, 3.0]) +""" + + interpreter = HaskellInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_java.py b/tests/interpreters/test_java.py index 2ebf0afe..67c5bf08 100644 --- a/tests/interpreters/test_java.py +++ b/tests/interpreters/test_java.py @@ -557,6 +557,38 @@ def test_atan_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_softmax_expr(): + expr = ast.SoftmaxExpr([ast.NumVal(2.0), ast.NumVal(3.0)]) + + interpreter = interpreters.JavaInterpreter() + + expected_code = """ +public class Model { + public static double[] score(double[] input) { + return softmax(new double[] {2.0, 3.0}); + } + private static double[] softmax(double[] x) { + int size = x.length; + double[] result = new double[size]; + double max = x[0]; + for (int i = 1; i < size; ++i) { + if (x[i] > max) + max = x[i]; + } + double sum = 0.0; + for (int i = 0; i < size; ++i) { + result[i] = Math.exp(x[i] - max); + sum += result[i]; + } + for (int i = 0; i < size; ++i) + result[i] /= sum; + return result; + } +}""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_javascript.py b/tests/interpreters/test_javascript.py index 5dfc3861..9fe9afb0 100644 --- a/tests/interpreters/test_javascript.py +++ b/tests/interpreters/test_javascript.py @@ -342,6 +342,37 @@ def test_atan_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_softmax_expr(): + expr = ast.SoftmaxExpr([ast.NumVal(2.0), ast.NumVal(3.0)]) + + interpreter = interpreters.JavascriptInterpreter() + + expected_code = """ +function score(input) { + return softmax([2.0, 3.0]); +} +function softmax(x) { + let size = x.length; + let result = new Array(size); + let max = x[0]; + for (let i = 1; i < size; ++i) { + if (x[i] > max) + max = x[i]; + } + let sum = 0.0; + for (let i = 0; i < size; ++i) { + result[i] = Math.exp(x[i] - max); + sum += result[i]; + } + for (let i = 0; i < size; ++i) + result[i] /= sum; + return result; +} +""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/interpreters/test_php.py b/tests/interpreters/test_php.py index d9b00ed1..cc9c5f59 100644 --- a/tests/interpreters/test_php.py +++ b/tests/interpreters/test_php.py @@ -347,6 +347,34 @@ def test_atan_expr(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) +def test_softmax_expr(): + expr = ast.SoftmaxExpr([ast.NumVal(2.0), ast.NumVal(3.0)]) + + expected_code = """ + max Then + max = x(i) + End If + Next i + Dim sum As Double + sum = 0.0 + For i = LBound(x) To UBound(x) + result(i) = Math.Exp(x(i) - max) + sum = sum + result(i) + Next i + For i = LBound(x) To UBound(x) + result(i) = result(i) / sum + Next i + Softmax = result +End Function +Function Score(ByRef inputVector() As Double) As Double() + Dim var0(1) As Double + var0(0) = 2.0 + var0(1) = 3.0 + Score = Softmax(var0) +End Function +End Module +""" + + interpreter = VisualBasicInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + def test_reused_expr(): reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) diff --git a/tests/test_ast.py b/tests/test_ast.py index cf37911a..af43d85a 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -60,7 +60,7 @@ def test_count_exprs_exclude_list(): ast.BinNumOpType.ADD) ]), ast.IdExpr( - ast.VectorVal([ + ast.SoftmaxExpr([ ast.NumVal(1), ast.NumVal(2), ast.NumVal(3), @@ -106,9 +106,9 @@ def test_exprs_str(): PowExpr(NumVal(2.0),NumVal(3.0),to_reuse=False), TanhExpr(NumVal(1.0),to_reuse=False), BinNumExpr(NumVal(0.0),FeatureRef(0),ADD,to_reuse=False)]), -IdExpr(VectorVal([ +IdExpr(SoftmaxExpr( NumVal(1.0),NumVal(2.0),NumVal(3.0),NumVal(4.0),NumVal(5.0), -NumVal(6.0),NumVal(7.0),NumVal(8.0),FeatureRef(1)]),to_reuse=False),SUB), +NumVal(6.0),NumVal(7.0),NumVal(8.0),FeatureRef(1),to_reuse=False),to_reuse=False),SUB), IfExpr(CompExpr(NumVal(2.0),NumVal(0.0),GT),NumVal(3.0),NumVal(4.0)),MUL) """.strip().replace("\n", "") diff --git a/tests/test_fallback_expressions.py b/tests/test_fallback_expressions.py index 8b58c85b..261ecf88 100644 --- a/tests/test_fallback_expressions.py +++ b/tests/test_fallback_expressions.py @@ -147,3 +147,27 @@ def score(input): """(var5)) + (var6)) * (var7)""") assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_softmax_fallback_expr(): + expr = ast.SoftmaxExpr([ast.NumVal(2.0), ast.NumVal(3.0)]) + + class InterpreterWithoutSoftmax(PythonInterpreter): + softmax_function_name = NotImplemented + + def interpret_softmax_expr(self, expr, **kwargs): + return super(PythonInterpreter, self).interpret_softmax_expr( + expr, **kwargs) + + interpreter = InterpreterWithoutSoftmax() + + expected_code = """ +import math +def score(input): + var0 = math.exp(2.0) + var1 = math.exp(3.0) + var2 = (var0) + (var1) + return [(var0) / (var2), (var1) / (var2)] +""" + + assert_code_equal(interpreter.interpret(expr), expected_code)