Skip to content

Commit

Permalink
continue developing haskell interpreter
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Apr 11, 2020
1 parent c79ee10 commit abdd261
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 155 deletions.
3 changes: 2 additions & 1 deletion m2cgen/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
"r": (m2cgen.export_to_r, ["indent", "function_name"]),
"php": (m2cgen.export_to_php, ["indent", "function_name"]),
"dart": (m2cgen.export_to_dart, ["indent", "function_name"]),
"haskell": (m2cgen.export_to_haskell, ["indent", "function_name"]),
"haskell": (m2cgen.export_to_haskell,
["module_name", "indent", "function_name"]),
}


Expand Down
33 changes: 32 additions & 1 deletion m2cgen/interpreters/haskell/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,43 @@ class HaskellCodeGenerator(BaseCodeGenerator):
tpl_module_definition = CodeTemplate("module ${module_name} where")

def __init__(self, *args, **kwargs):
super(HaskellCodeGenerator, self).__init__(*args, **kwargs)
self._func_idx = 0
super().__init__(*args, **kwargs)

def reset_state(self):
super().reset_state()
self._func_idx = 0

def array_index_access(self, array_name, index):
return self.tpl_infix_expression(
left=array_name, op="!!", right=index)

def add_if_statement(self, if_def):
self.add_code_line("if ({})".format(if_def))
self.increase_indent()
self.add_code_line("then")
self.increase_indent()

def add_else_statement(self):
self.decrease_indent()
self.add_code_line("else")
self.increase_indent()

def add_if_termination(self):
self.decrease_indent()
self.decrease_indent()

def get_func_name(self):
func_name = "func" + str(self._func_idx)
self._func_idx += 1
return func_name

def add_function(self, function_name, function_body):
self.add_code_line("{} =".format(function_name))
self.increase_indent()
self.add_code_lines(function_body)
self.decrease_indent()

def function_invocation(self, function_name, *args):
return (function_name + " " +
" ".join(map(lambda x: "({})".format(x), args)))
Expand Down
47 changes: 44 additions & 3 deletions m2cgen/interpreters/haskell/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ class HaskellInterpreter(ToCodeInterpreter,
}

exponent_function_name = "exp"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"

def __init__(self, module_name="Model", indent=4, function_name="score",
*args, **kwargs):
self.module_name = module_name
self.indent = indent
self.function_name = function_name

cg = HaskellCodeGenerator(indent=indent)
Expand All @@ -31,9 +33,6 @@ def interpret(self, expr):
self._cg.reset_state()
self._reset_reused_expr_cache()

self._cg.add_code_line(self._cg.tpl_module_definition(
module_name=self.module_name))

args = [(True, self._feature_array_name)]
func_name = self.function_name

Expand All @@ -43,16 +42,58 @@ def interpret(self, expr):
is_scalar_output=expr.output_size == 1):
last_result = self._do_interpret(expr)
self._cg.add_code_line(last_result)
self._dump_cache()

if self.with_linear_algebra:
filename = os.path.join(
os.path.dirname(__file__), "linear_algebra.hs")
self._cg.prepend_code_lines(utils.get_file_content(filename))

self._cg.prepend_code_line(self._cg.tpl_module_definition(
module_name=self.module_name))

return self._cg.code

def interpret_if_expr(self, expr, if_code_gen=None, **kwargs):
if if_code_gen is None:
code_gen = HaskellCodeGenerator(indent=self.indent)
nested = False
else:
code_gen = if_code_gen
nested = True

code_gen.add_if_statement(self._do_interpret(
expr.test, **kwargs))
code_gen.add_code_line(self._do_interpret(
expr.body, if_code_gen=code_gen, **kwargs))
code_gen.add_else_statement()
code_gen.add_code_line(self._do_interpret(
expr.orelse, if_code_gen=code_gen, **kwargs))
code_gen.add_if_termination()

if not nested:
return self._cache_reused_expr(expr, code_gen.code)

def interpret_pow_expr(self, expr, **kwargs):
base_result = self._do_interpret(expr.base_expr, **kwargs)
exp_result = self._do_interpret(expr.exp_expr, **kwargs)
return self._cg.infix_expression(
left=base_result, right=exp_result, op="**")

def _cache_reused_expr(self, expr, expr_result):
if expr in self._cached_expr_results:
return self._cached_expr_results[expr].var_name
else:
func_name = self._cg.get_func_name()
self._cached_expr_results[expr] = utils.CacheResult(
var_name=func_name, expr_result=expr_result)
return func_name

def _dump_cache(self):
if self._cached_expr_results:
self._cg.add_code_line("where")
self._cg.increase_indent()
for func_name, expr_result in self._cached_expr_results.values():
self._cg.add_function(
function_name=func_name, function_body=expr_result)
self._cg.decrease_indent()
1 change: 0 additions & 1 deletion m2cgen/interpreters/haskell/linear_algebra.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
addVectors :: [Double] -> [Double] -> [Double]
addVectors v1 v2 = zipWith (+) v1 v2

mulVectorNumber :: [Double] -> Double -> [Double]
mulVectorNumber v1 num = [i * num | i <- v1]
6 changes: 4 additions & 2 deletions m2cgen/interpreters/interpreter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re

from m2cgen import ast
from m2cgen.interpreters.utils import CacheResult


class BaseInterpreter:
Expand Down Expand Up @@ -41,7 +42,7 @@ def _do_interpret(self, expr, to_reuse=None, **kwargs):
return handler(expr, **kwargs)

if expr in self._cached_expr_results:
return self._cached_expr_results[expr]
return self._cached_expr_results[expr].var_name

result = handler(expr, **kwargs)
return self._cache_reused_expr(expr, result)
Expand Down Expand Up @@ -193,5 +194,6 @@ def handle_nested_expr(nested):
def _cache_reused_expr(self, expr, expr_result):
var_name = self._cg.add_var_declaration(expr.output_size)
self._cg.add_var_assignment(var_name, expr_result, expr.output_size)
self._cached_expr_results[expr] = var_name
self._cached_expr_results[expr] = CacheResult(
var_name=var_name, expr_result=None)
return var_name
6 changes: 6 additions & 0 deletions m2cgen/interpreters/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from collections import namedtuple


CacheResult = namedtuple('CacheResult', ['var_name', 'expr_result'])


def get_file_content(path):
with open(path) as f:
return f.read()
4 changes: 2 additions & 2 deletions tests/e2e/executors/haskell.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@
from tests.e2e.executors import base

EXECUTOR_CODE_TPL = """
module Main where
module ${executor_name} where
import System.Environment (getArgs)
import ${model_name}
main = do
args <- getArgs
let inputArray = [read i::Double | i <- args]
let res = score inputArray
${print_code}
"""

Expand Down Expand Up @@ -51,6 +50,7 @@ def prepare(self):
else:
print_code = EXECUTE_AND_PRINT_SCALAR
executor_code = string.Template(EXECUTOR_CODE_TPL).substitute(
executor_name=self.executor_name,
model_name=self.model_name,
print_code=print_code)
model_code = self.interpreter.interpret(self.model_ast)
Expand Down
Loading

0 comments on commit abdd261

Please sign in to comment.