Skip to content

Commit

Permalink
more internal juggling
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcinKonowalczyk committed Apr 11, 2023
1 parent 3d835a2 commit cda32eb
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 99 deletions.
6 changes: 4 additions & 2 deletions psll/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.1.2"
__version__ = "0.1.3"

import sys

Expand All @@ -13,5 +13,7 @@ class PsllSyntaxError(SyntaxError):
from . import ( # noqa: F401, E402
preprocessor,
lexer,
compiler,
macros,
build,
optimisers,
)
15 changes: 9 additions & 6 deletions psll/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,13 @@ def valid_output_file(args, ext=".pyra"):
#
# ======================================================================

from . import preprocessor # noqa: E402
from . import lexer # noqa: E402
from . import compiler # noqa: E402
from . import optimisers # noqa: E402
from . import ( # noqa: E402
preprocessor,
lexer,
macros,
build,
optimisers,
)


def main(args): # pragma: no cover
Expand All @@ -135,15 +138,15 @@ def main(args): # pragma: no cover
# names = find_variable_names(ast)
# print('variables:',variables)

ast = compiler.apply_processing_stack(ast, full_names=args.full_names)
ast = macros.apply_processing_stack(ast, full_names=args.full_names)
# print(ast)
# TODO Make optimisation options mutually exclusive
if args.considerate_optimisation:
ast = optimisers.considerate_optimisation(ast, max_iter=None)
if args.greedy_optimisation:
ast = optimisers.greedy_optimisation(ast, max_iter=None)

program = compiler.compile(ast)
program = build.build(ast)
if args.verbose:
print("Pyramid scheme:", program, sep="\n")

Expand Down
47 changes: 47 additions & 0 deletions psll/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from functools import reduce
import operator
from functools import lru_cache as cached
from .ascii_trees import Pyramid

# ===================================================================================
#
# ##### ## ## #### ## #####
# ## ## ## ## ## ## ## ##
# ##### ## ## ## ## ## ##
# ## ## ## ## ## ## ## ##
# ##### ##### #### ###### #####
#
# ===================================================================================


@cached(maxsize=10000)
def build_tree(ast):
"""Build the call tree from the leaves to the root"""

if isinstance(ast, str):
return Pyramid.from_text(ast)
elif ast is None:
return None
elif isinstance(ast, tuple):
if len(ast) != 3:
raise RuntimeError(
f"Invalid structure of the abstract syntax tree. ({ast})"
)
if not isinstance(ast[0], str):
raise RuntimeError(
"Invalid abstract syntax tree. The first element of each node must be"
f" a string, not a {type(ast[0])}"
)
return build_tree(ast[0]) + (build_tree(ast[1]), build_tree(ast[2]))
else:
raise TypeError(
"Abstract syntax tree must be represented by a list (or just a string) not"
f" a {type(ast)}"
)


def build(ast) -> str:
"""Build the program from the abstract syntax tree"""
program = str(reduce(operator.add, (build_tree(a) for a in ast)))
# Remove excessive whitespace
return "\n".join(line[1:].rstrip() for line in program.split("\n"))
48 changes: 0 additions & 48 deletions psll/compiler.py → psll/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@
from more_itertools import windowed

from functools import partial, reduce
from functools import lru_cache as cached
import operator
from string import ascii_letters

from . import PsllSyntaxError
from .ascii_trees import Pyramid
from . import lexer


Expand Down Expand Up @@ -477,48 +474,3 @@ def apply_processing_stack(ast: Node, full_names: bool = False) -> Node:
"""Apply the processing stack to the ast"""
stack = __processing_stack__[1:] if full_names else __processing_stack__
return reduce(lambda x, y: y(x), [ast] + list(stack)) # type: ignore


# =========================================================================================
#
# #### ##### ### ### ##### ## ## #####
# ## ## ## ## # # ## ## ## ## ## ##
# ## ## ## ## ## ## ##### ## ## #####
# ## ## ## ## ## ## ## ## ##
# #### ##### ## ## ## ## ###### #####
#
# =========================================================================================


@cached(maxsize=10000)
def build_tree(ast):
"""Build the call tree from the leaves to the root"""

if isinstance(ast, str):
return Pyramid.from_text(ast)
elif ast is None:
return None
elif isinstance(ast, tuple):
if len(ast) != 3:
raise RuntimeError(
f"Invalid structure of the abstract syntax tree. ({ast})"
)
if not isinstance(ast[0], str):
raise RuntimeError(
"Invalid abstract syntax tree. The first element of each node must be"
f" a string, not a {type(ast[0])}"
)
return build_tree(ast[0]) + (build_tree(ast[1]), build_tree(ast[2]))
else:
raise TypeError(
"Abstract syntax tree must be represented by a list (or just a string) not"
f" a {type(ast)}"
)


# TODO Refactor
def compile(ast) -> str:
"""Compile text into trees"""
program = str(reduce(operator.add, (build_tree(a) for a in ast)))
# Remove excessive whitespace
return "\n".join(line[1:].rstrip() for line in program.split("\n"))
11 changes: 6 additions & 5 deletions psll/optimisers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from more_itertools import windowed_complete
import operator

from . import compiler
from . import build


def greedy_optimisation(ast, verbose: bool = True, max_iter: Optional[int] = None):
"""Greedily insert empty trees into the abstract syntax tree"""
Expand All @@ -28,9 +29,9 @@ def candidates(ast):
if max_iter and iter_count > max_iter:
break

N = len(compiler.compile(ast))
N = len(build.build(ast))
for candidate in candidates(ast):
M = len(compiler.compile(candidate))
M = len(build.build(candidate))
if M < N:
if verbose:
print(f"{iter_count} | Old len: {N} | New len: {M}")
Expand Down Expand Up @@ -74,8 +75,8 @@ def candidates(ast):
if max_iter and iter_count > max_iter:
break

N = len(compiler.compile(ast))
lengths = ((len(compiler.compile(c)), c) for c in candidates(ast))
N = len(build.build(ast))
lengths = ((len(build.build(c)), c) for c in candidates(ast))
M, candidate = min(lengths, key=operator.itemgetter(0))
if M < N:
if verbose:
Expand Down

0 comments on commit cda32eb

Please sign in to comment.