Skip to content

Commit

Permalink
continue implementing
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Sep 23, 2020
1 parent e742507 commit 8e2b115
Show file tree
Hide file tree
Showing 38 changed files with 692 additions and 115 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Expand Up @@ -4,4 +4,5 @@ recursive-include m2cgen linear_algebra.*
recursive-include m2cgen log1p.*
recursive-include m2cgen tanh.*
recursive-include m2cgen atan.*
recursive-include m2cgen softmax.*
global-exclude *.py[cod]
10 changes: 5 additions & 5 deletions m2cgen/assemblers/boosting.py
Expand Up @@ -72,8 +72,7 @@ def _assemble_multi_class_output(self, estimator_params):
for i, e in enumerate(splits)
]

proba_exprs = self._multi_class_convert_output(exprs)
return ast.VectorVal(proba_exprs)
return self._multi_class_convert_output(exprs)

def _assemble_bin_class_output(self, estimator_params):
# Base score is calculated based on
Expand All @@ -97,7 +96,7 @@ def _final_transform(self, ast_to_transform):
return ast_to_transform

def _multi_class_convert_output(self, exprs):
return fallback_expressions.softmax(exprs)
return ast.SoftmaxExpr(exprs)

def _bin_class_convert_output(self, expr, to_reuse=True):
return fallback_expressions.sigmoid(expr, to_reuse=to_reuse)
Expand Down Expand Up @@ -250,8 +249,9 @@ def _multi_class_convert_output(self, exprs):
return supported_objectives[self.objective_name](exprs)

def _multi_class_sigmoid_transform(self, exprs):
return [self._bin_class_sigmoid_transform(expr, to_reuse=False)
for expr in exprs]
return ast.VectorVal([
self._bin_class_sigmoid_transform(expr, to_reuse=False)
for expr in exprs])

def _bin_class_convert_output(self, expr, to_reuse=True):
supported_objectives = {
Expand Down
4 changes: 2 additions & 2 deletions m2cgen/assemblers/fallback_expressions.py
Expand Up @@ -190,7 +190,7 @@ def softmax(exprs):
exp_exprs = [ast.ExpExpr(e, to_reuse=True) for e in exprs]
exp_sum_expr = utils.apply_op_to_expressions(
ast.BinNumOpType.ADD, *exp_exprs, to_reuse=True)
return [
return ast.VectorVal([
ast.BinNumExpr(e, exp_sum_expr, ast.BinNumOpType.DIV)
for e in exp_exprs
]
])
29 changes: 26 additions & 3 deletions m2cgen/ast.py
Expand Up @@ -266,6 +266,29 @@ def __hash__(self):
return hash(tuple(self.exprs))


class SoftmaxExpr(VectorExpr):

def __init__(self, exprs, to_reuse=False):
assert all(e.output_size == 1 for e in exprs), (
"All expressions for SoftmaxExpr must be scalar")

self.exprs = exprs
self.to_reuse = to_reuse
self.output_size = len(exprs)

def __str__(self):
args = ",".join([str(e) for e in self.exprs])
return f"SoftmaxExpr({args},to_reuse={self.to_reuse})"

def __eq__(self, other):
return (type(other) is SoftmaxExpr and
self.output_size == other.output_size and
all(i == j for i, j in zip(self.exprs, other.exprs)))

def __hash__(self):
return hash(tuple(self.exprs))


class BinVectorExpr(VectorExpr, BinExpr):

def __init__(self, left, right, op):
Expand Down Expand Up @@ -384,9 +407,9 @@ def __hash__(self):

NESTED_EXPRS_MAPPINGS = [
((BinExpr, CompExpr), lambda e: [e.left, e.right]),
(PowExpr, lambda e: [e.base_expr, e.exp_expr]),
(VectorVal, lambda e: e.exprs),
(IfExpr, lambda e: [e.test, e.body, e.orelse]),
((PowExpr), lambda e: [e.base_expr, e.exp_expr]),
((VectorVal, SoftmaxExpr), lambda e: e.exprs),
((IfExpr), lambda e: [e.test, e.body, e.orelse]),
((AbsExpr, AtanExpr, ExpExpr, IdExpr, LogExpr, Log1pExpr,
SqrtExpr, TanhExpr),
lambda e: [e.expr]),
Expand Down
23 changes: 23 additions & 0 deletions m2cgen/interpreters/c/interpreter.py
Expand Up @@ -23,9 +23,12 @@ class CInterpreter(ImperativeToCodeInterpreter,
logarithm_function_name = "log"
log1p_function_name = "log1p"
power_function_name = "pow"
softmax_function_name = "softmax"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"

with_softmax_expr = False

def __init__(self, indent=4, function_name="score", *args, **kwargs):
self.function_name = function_name

Expand Down Expand Up @@ -61,6 +64,11 @@ def interpret(self, expr):
os.path.dirname(__file__), "linear_algebra.c")
self._cg.prepend_code_lines(utils.get_file_content(filename))

if self.with_softmax_expr:
filename = os.path.join(
os.path.dirname(__file__), "softmax.c")
self._cg.prepend_code_lines(utils.get_file_content(filename))

if self.with_vectors:
self._cg.add_dependency("<string.h>")

Expand Down Expand Up @@ -99,3 +107,18 @@ def interpret_bin_vector_num_expr(self, expr, **kwargs):
self._cg.add_code_line(f"{func_inv};")

return var_name

# Do the same things for softmax as for linear algebra.
def interpret_softmax_expr(self, expr, **kwargs):
self.with_vectors = True
self.with_softmax_expr = True

var_name = self._cg.add_var_declaration(expr.output_size)
nested = [self._do_interpret(expr, **kwargs) for expr in expr.exprs]
func_inv = self._cg.function_invocation(
self.softmax_function_name,
self._cg.vector_init(nested),
expr.output_size,
var_name)
self._cg.add_code_line(f"{func_inv};")
return var_name
11 changes: 11 additions & 0 deletions m2cgen/interpreters/c_sharp/interpreter.py
Expand Up @@ -24,10 +24,12 @@ class CSharpInterpreter(ImperativeToCodeInterpreter,
logarithm_function_name = "Log"
log1p_function_name = "Log1p"
power_function_name = "Pow"
softmax_function_name = "Softmax"
sqrt_function_name = "Sqrt"
tanh_function_name = "Tanh"

with_log1p_expr = False
with_softmax_expr = False

def __init__(self, namespace="ML", class_name="Model", indent=4,
function_name="Score", *args, **kwargs):
Expand Down Expand Up @@ -66,6 +68,11 @@ def interpret(self, expr):
os.path.dirname(__file__), "log1p.cs")
self._cg.add_code_lines(utils.get_file_content(filename))

if self.with_softmax_expr:
filename = os.path.join(
os.path.dirname(__file__), "softmax.cs")
self._cg.add_code_lines(utils.get_file_content(filename))

if self.with_math_module:
self._cg.add_dependency("System.Math")

Expand All @@ -74,3 +81,7 @@ def interpret(self, expr):
def interpret_log1p_expr(self, expr, **kwargs):
self.with_log1p_expr = True
return super().interpret_log1p_expr(expr, **kwargs)

def interpret_softmax_expr(self, expr, **kwargs):
self.with_softmax_expr = True
return super().interpret_softmax_expr(expr, **kwargs)
11 changes: 11 additions & 0 deletions m2cgen/interpreters/dart/interpreter.py
Expand Up @@ -27,10 +27,12 @@ class DartInterpreter(ImperativeToCodeInterpreter,
logarithm_function_name = "log"
log1p_function_name = "log1p"
power_function_name = "pow"
softmax_function_name = "softmax"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"

with_log1p_expr = False
with_softmax_expr = False
with_tanh_expr = False

def __init__(self, indent=4, function_name="score", *args, **kwargs):
Expand Down Expand Up @@ -63,6 +65,11 @@ def interpret(self, expr):
os.path.dirname(__file__), "log1p.dart")
self._cg.add_code_lines(utils.get_file_content(filename))

if self.with_softmax_expr:
filename = os.path.join(
os.path.dirname(__file__), "softmax.dart")
self._cg.add_code_lines(utils.get_file_content(filename))

if self.with_tanh_expr:
filename = os.path.join(
os.path.dirname(__file__), "tanh.dart")
Expand All @@ -86,3 +93,7 @@ def interpret_log1p_expr(self, expr, **kwargs):
def interpret_tanh_expr(self, expr, **kwargs):
self.with_tanh_expr = True
return super().interpret_tanh_expr(expr, **kwargs)

def interpret_softmax_expr(self, expr, **kwargs):
self.with_softmax_expr = True
return super().interpret_softmax_expr(expr, **kwargs)
11 changes: 11 additions & 0 deletions m2cgen/interpreters/f_sharp/interpreter.py
Expand Up @@ -30,10 +30,12 @@ class FSharpInterpreter(FunctionalToCodeInterpreter,
exponent_function_name = "exp"
logarithm_function_name = "log"
log1p_function_name = "log1p"
softmax_function_name = "softmax"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"

with_log1p_expr = False
with_softmax_expr = False

def __init__(self, indent=4, function_name="score", *args, **kwargs):
self.indent = indent
Expand Down Expand Up @@ -63,6 +65,11 @@ def interpret(self, expr):
os.path.dirname(__file__), "log1p.fs")
self._cg.prepend_code_lines(utils.get_file_content(filename))

if self.with_softmax_expr:
filename = os.path.join(
os.path.dirname(__file__), "softmax.fs")
self._cg.prepend_code_lines(utils.get_file_content(filename))

return self._cg.finalize_and_get_generated_code()

def create_code_generator(self):
Expand All @@ -78,6 +85,10 @@ def interpret_log1p_expr(self, expr, **kwargs):
self.with_log1p_expr = True
return super().interpret_log1p_expr(expr, **kwargs)

def interpret_softmax_expr(self, expr, **kwargs):
self.with_softmax_expr = True
return super().interpret_softmax_expr(expr, **kwargs)

def _dump_cache(self):
if self._cached_expr_results:
for func_name, expr_result in self._cached_expr_results.values():
Expand Down
12 changes: 12 additions & 0 deletions m2cgen/interpreters/go/interpreter.py
Expand Up @@ -22,9 +22,12 @@ class GoInterpreter(ImperativeToCodeInterpreter,
logarithm_function_name = "math.Log"
log1p_function_name = "math.Log1p"
power_function_name = "math.Pow"
softmax_function_name = "softmax"
sqrt_function_name = "math.Sqrt"
tanh_function_name = "math.Tanh"

with_softmax_expr = False

def __init__(self, indent=4, function_name="score", *args, **kwargs):
self.function_name = function_name

Expand All @@ -51,7 +54,16 @@ def interpret(self, expr):
os.path.dirname(__file__), "linear_algebra.go")
self._cg.prepend_code_lines(utils.get_file_content(filename))

if self.with_softmax_expr:
filename = os.path.join(
os.path.dirname(__file__), "softmax.go")
self._cg.prepend_code_lines(utils.get_file_content(filename))

if self.with_math_module:
self._cg.add_dependency("math")

return self._cg.finalize_and_get_generated_code()

def interpret_softmax_expr(self, expr, **kwargs):
self.with_softmax_expr = True
return super().interpret_softmax_expr(expr, **kwargs)
11 changes: 11 additions & 0 deletions m2cgen/interpreters/haskell/interpreter.py
Expand Up @@ -21,10 +21,12 @@ class HaskellInterpreter(FunctionalToCodeInterpreter,
exponent_function_name = "exp"
logarithm_function_name = "log"
log1p_function_name = "log1p"
softmax_function_name = "softmax"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"

with_log1p_expr = False
with_softmax_expr = False

def __init__(self, module_name="Model", indent=4, function_name="score",
*args, **kwargs):
Expand Down Expand Up @@ -59,6 +61,11 @@ def interpret(self, expr):
os.path.dirname(__file__), "log1p.hs")
self._cg.prepend_code_lines(utils.get_file_content(filename))

if self.with_softmax_expr:
filename = os.path.join(
os.path.dirname(__file__), "softmax.hs")
self._cg.prepend_code_lines(utils.get_file_content(filename))

self._cg.prepend_code_line(self._cg.tpl_module_definition(
module_name=self.module_name))

Expand All @@ -77,6 +84,10 @@ def interpret_log1p_expr(self, expr, **kwargs):
self.with_log1p_expr = True
return super().interpret_log1p_expr(expr, **kwargs)

def interpret_softmax_expr(self, expr, **kwargs):
self.with_softmax_expr = True
return super().interpret_softmax_expr(expr, **kwargs)

def _dump_cache(self):
if self._cached_expr_results:
self._cg.add_code_line("where")
Expand Down
12 changes: 12 additions & 0 deletions m2cgen/interpreters/interpreter.py
Expand Up @@ -87,6 +87,7 @@ class ToCodeInterpreter(BaseToCodeInterpreter):
logarithm_function_name = NotImplemented
log1p_function_name = NotImplemented
power_function_name = NotImplemented
softmax_function_name = NotImplemented
sqrt_function_name = NotImplemented
tanh_function_name = NotImplemented

Expand Down Expand Up @@ -169,6 +170,17 @@ def interpret_log1p_expr(self, expr, **kwargs):
return self._cg.function_invocation(
self.log1p_function_name, nested_result)

def interpret_softmax_expr(self, expr, **kwargs):
if self.softmax_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.softmax(expr.exprs),
**kwargs)
self.with_vectors = True
self.with_math_module = True
nested = [self._do_interpret(expr, **kwargs) for expr in expr.exprs]
return self._cg.function_invocation(
self.softmax_function_name, self._cg.vector_init(nested))

def interpret_sqrt_expr(self, expr, **kwargs):
if self.sqrt_function_name is NotImplemented:
return self._do_interpret(
Expand Down
12 changes: 12 additions & 0 deletions m2cgen/interpreters/java/interpreter.py
Expand Up @@ -30,9 +30,12 @@ class JavaInterpreter(ImperativeToCodeInterpreter,
logarithm_function_name = "Math.log"
log1p_function_name = "Math.log1p"
power_function_name = "Math.pow"
softmax_function_name = "softmax"
sqrt_function_name = "Math.sqrt"
tanh_function_name = "Math.tanh"

with_softmax_expr = False

def __init__(self, package_name=None, class_name="Model", indent=4,
function_name="score", *args, **kwargs):
self.package_name = package_name
Expand Down Expand Up @@ -64,9 +67,18 @@ def interpret(self, expr):
os.path.dirname(__file__), "linear_algebra.java")
top_cg.add_code_lines(utils.get_file_content(filename))

if self.with_softmax_expr:
filename = os.path.join(
os.path.dirname(__file__), "softmax.java")
top_cg.add_code_lines(utils.get_file_content(filename))

return top_cg.finalize_and_get_generated_code()

# Required by SubroutinesMixin to create new code generator for
# each subroutine.
def create_code_generator(self):
return JavaCodeGenerator(indent=self.indent)

def interpret_softmax_expr(self, expr, **kwargs):
self.with_softmax_expr = True
return super().interpret_softmax_expr(expr, **kwargs)
12 changes: 12 additions & 0 deletions m2cgen/interpreters/javascript/interpreter.py
Expand Up @@ -25,9 +25,12 @@ class JavascriptInterpreter(ImperativeToCodeInterpreter,
logarithm_function_name = "Math.log"
log1p_function_name = "Math.log1p"
power_function_name = "Math.pow"
softmax_function_name = "softmax"
sqrt_function_name = "Math.sqrt"
tanh_function_name = "Math.tanh"

with_softmax_expr = False

def __init__(self, indent=4, function_name="score",
*args, **kwargs):
self.indent = indent
Expand All @@ -53,4 +56,13 @@ def interpret(self, expr):
os.path.dirname(__file__), "linear_algebra.js")
self._cg.add_code_lines(utils.get_file_content(filename))

if self.with_softmax_expr:
filename = os.path.join(
os.path.dirname(__file__), "softmax.js")
self._cg.add_code_lines(utils.get_file_content(filename))

return self._cg.finalize_and_get_generated_code()

def interpret_softmax_expr(self, expr, **kwargs):
self.with_softmax_expr = True
return super().interpret_softmax_expr(expr, **kwargs)

0 comments on commit 8e2b115

Please sign in to comment.