From 1102e04696db97e3394d1fdc5c0d200b7f331629 Mon Sep 17 00:00:00 2001 From: Daniel Xu Date: Thu, 6 Dec 2018 16:26:59 -0500 Subject: [PATCH] Added a subtrees feature --- ast.py | 42 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/ast.py b/ast.py index 72129b1..b0a58be 100644 --- a/ast.py +++ b/ast.py @@ -20,6 +20,9 @@ def __init__(self, sym, parent=None, left=None, right=None): self.parent = parent self.right = right + def is_leaf(self): + return self.left is None and self.right is None + def copy(self): return node(self.sym, self.parent, self.left, self.right) @@ -38,8 +41,8 @@ class astree(): After constructing the tree completely, one can use `preorder`, `inorder` and `postorder` to convert the expression to prefix, infix and postfix format. """ - def __init__(self): - self.root = None + def __init__(self, root=None): + self.root = root self.cur = None def add(self, sym): @@ -232,6 +235,30 @@ def _level_traversal(root, level, tlist): _level_traversal(root.right, level+1, tlist) +def subtrees(a, roots=[], max_depth=None): + """ + Get all possible subtrees from root, left and then right. It ends if + max depth is given and the depth is reached, otherwise, it will return all + subtrees. If roots are given, only the trees that with roots that are inside + the given list will be returned. + """ + li = [] + _subtree(a.root, 0, 100, li) + result = [] + for tree in li: + if len(roots) == 0 or tree.root.sym in roots: + result.append(tree) + return result + + +def _subtree(root, depth, max_depth, li): + if root is None or root.is_leaf() or depth == max_depth: + return + li.append(astree(root)) + _subtree(root.left, depth+1, max_depth, li) + _subtree(root.right, depth+1, max_depth, li) + + def _max_depth(n): if n is None: return 0 @@ -311,10 +338,13 @@ def main(): def test(): - a = build("-1*--2^-2") - view(a) - print(evaluate(a)) + a = build("5*(2+3)") + l = subtrees(a, roots=[]) + for i in l: + view(i) + print(evaluate(i)) + if __name__ == "__main__": - main() + test()