diff --git a/m2cgen/interpreters/c/interpreter.py b/m2cgen/interpreters/c/interpreter.py index 8c5a2ee6..51c89cb8 100644 --- a/m2cgen/interpreters/c/interpreter.py +++ b/m2cgen/interpreters/c/interpreter.py @@ -63,7 +63,7 @@ def interpret(self, expr): if self.with_math_module: self._cg.add_dependency("") - return self._cg.code + return self._cg.finalize_and_get_generated_code() # Both methods supporting linear algebra do several things: # diff --git a/m2cgen/interpreters/c_sharp/interpreter.py b/m2cgen/interpreters/c_sharp/interpreter.py index 6e32b763..d72d599e 100644 --- a/m2cgen/interpreters/c_sharp/interpreter.py +++ b/m2cgen/interpreters/c_sharp/interpreter.py @@ -58,4 +58,4 @@ def interpret(self, expr): if self.with_math_module: self._cg.add_dependency("System.Math") - return self._cg.code + return self._cg.finalize_and_get_generated_code() diff --git a/m2cgen/interpreters/code_generator.py b/m2cgen/interpreters/code_generator.py index 36401c55..dc5dc523 100644 --- a/m2cgen/interpreters/code_generator.py +++ b/m2cgen/interpreters/code_generator.py @@ -1,4 +1,6 @@ +from io import StringIO from string import Template +from weakref import finalize class CodeTemplate: @@ -29,11 +31,38 @@ class BaseCodeGenerator: def __init__(self, indent=4): self._indent = indent + self._code_buf = None self.reset_state() def reset_state(self): self._current_indent = 0 - self.code = "" + self._finalize_buffer() + self._code_buf = StringIO() + self._code = None + self._finalizer = finalize(self, self._finalize_buffer) + + def _finalize_buffer(self): + if self._code_buf is not None and not self._code_buf.closed: + self._code_buf.close() + + def _write_to_code_buffer(self, text, prepend=False): + if self._code_buf.closed: + raise BufferError( + "Cannot modify code after getting generated code and " + "closing the underlying buffer!\n" + "Call reset_state() to allocate new buffer.") + if prepend: + self._code_buf.seek(0) + old_content = self._code_buf.read() + self._code_buf.seek(0) + text += old_content + self._code_buf.write(text) + + def finalize_and_get_generated_code(self): + if not self._code_buf.closed: + self._code = self._code_buf.getvalue() + self._finalize_buffer() + return self._code if self._code is not None else "" def increase_indent(self): self._current_indent += self._indent @@ -48,22 +77,24 @@ def decrease_indent(self): def add_code_line(self, line): if not line: return - indent = " " * self._current_indent - self.code += indent + line + "\n" + self.add_code_lines([line.strip()]) def add_code_lines(self, lines): if isinstance(lines, str): lines = lines.strip().split("\n") indent = " " * self._current_indent - self.code += indent + "\n{}".format(indent).join(lines) + "\n" + self._write_to_code_buffer( + indent + "\n{}".format(indent).join(lines) + "\n") def prepend_code_line(self, line): - self.code = line + "\n" + self.code + if not line: + return + self.prepend_code_lines([line.strip()]) def prepend_code_lines(self, lines): if isinstance(lines, str): lines = lines.strip().split("\n") - self.code = "\n".join(lines) + "\n" + self.code + self._write_to_code_buffer("\n".join(lines) + "\n", prepend=True) # Following methods simply compute expressions using templates without # changing result. diff --git a/m2cgen/interpreters/dart/interpreter.py b/m2cgen/interpreters/dart/interpreter.py index 34d658e3..967aa296 100644 --- a/m2cgen/interpreters/dart/interpreter.py +++ b/m2cgen/interpreters/dart/interpreter.py @@ -62,7 +62,7 @@ def interpret(self, expr): if self.with_math_module: self._cg.add_dependency("dart:math") - return self._cg.code + return self._cg.finalize_and_get_generated_code() def interpret_tanh_expr(self, expr, **kwargs): self.with_tanh_expr = True diff --git a/m2cgen/interpreters/go/interpreter.py b/m2cgen/interpreters/go/interpreter.py index 95ca0ee2..74826f1b 100644 --- a/m2cgen/interpreters/go/interpreter.py +++ b/m2cgen/interpreters/go/interpreter.py @@ -50,4 +50,4 @@ def interpret(self, expr): if self.with_math_module: self._cg.add_dependency("math") - return self._cg.code + return self._cg.finalize_and_get_generated_code() diff --git a/m2cgen/interpreters/haskell/interpreter.py b/m2cgen/interpreters/haskell/interpreter.py index 577aa8af..3c37010e 100644 --- a/m2cgen/interpreters/haskell/interpreter.py +++ b/m2cgen/interpreters/haskell/interpreter.py @@ -52,7 +52,7 @@ def interpret(self, expr): self._cg.prepend_code_line(self._cg.tpl_module_definition( module_name=self.module_name)) - return self._cg.code + return self._cg.finalize_and_get_generated_code() def interpret_if_expr(self, expr, if_code_gen=None, **kwargs): if if_code_gen is None: @@ -72,7 +72,8 @@ def interpret_if_expr(self, expr, if_code_gen=None, **kwargs): code_gen.add_if_termination() if not nested: - return self._cache_reused_expr(expr, code_gen.code) + return self._cache_reused_expr( + expr, code_gen.finalize_and_get_generated_code()) def interpret_pow_expr(self, expr, **kwargs): base_result = self._do_interpret(expr.base_expr, **kwargs) diff --git a/m2cgen/interpreters/java/interpreter.py b/m2cgen/interpreters/java/interpreter.py index 01b4d41a..92d03007 100644 --- a/m2cgen/interpreters/java/interpreter.py +++ b/m2cgen/interpreters/java/interpreter.py @@ -60,7 +60,7 @@ def interpret(self, expr): os.path.dirname(__file__), "linear_algebra.java") top_cg.add_code_lines(utils.get_file_content(filename)) - return top_cg.code + return top_cg.finalize_and_get_generated_code() # Required by SubroutinesMixin to create new code generator for # each subroutine. diff --git a/m2cgen/interpreters/javascript/interpreter.py b/m2cgen/interpreters/javascript/interpreter.py index 09fb05e7..366b894b 100644 --- a/m2cgen/interpreters/javascript/interpreter.py +++ b/m2cgen/interpreters/javascript/interpreter.py @@ -49,4 +49,4 @@ def interpret(self, expr): os.path.dirname(__file__), "linear_algebra.js") self._cg.add_code_lines(utils.get_file_content(filename)) - return self._cg.code + return self._cg.finalize_and_get_generated_code() diff --git a/m2cgen/interpreters/mixins.py b/m2cgen/interpreters/mixins.py index 199633bb..fd4bc454 100644 --- a/m2cgen/interpreters/mixins.py +++ b/m2cgen/interpreters/mixins.py @@ -187,7 +187,7 @@ def _process_subroutine(self, subroutine): last_result = self._do_interpret(subroutine.expr) self._cg.add_return_statement(last_result) - return self._cg.code + return self._cg.finalize_and_get_generated_code() def _get_subroutine_name(self): subroutine_name = "subroutine" + str(self._subroutine_idx) diff --git a/m2cgen/interpreters/php/interpreter.py b/m2cgen/interpreters/php/interpreter.py index 06d0416b..3f288c1d 100644 --- a/m2cgen/interpreters/php/interpreter.py +++ b/m2cgen/interpreters/php/interpreter.py @@ -45,4 +45,4 @@ def interpret(self, expr): self._cg.prepend_code_line("