diff --git a/CHANGELOG.md b/CHANGELOG.md index edd0c0c58..456bdf0c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added +- ProofTree.count\_objects\_of\_length() implements the recurrence relation + implied by the proof tree as long as the strategies used are only disjoint unions, + decompositions, verification or recursion. + ### Fixed - Update the readme and test it - Added missing equation for "F_root" case diff --git a/comb_spec_searcher/combinatorial_class.py b/comb_spec_searcher/combinatorial_class.py index 88418e4a5..f2d88777d 100644 --- a/comb_spec_searcher/combinatorial_class.py +++ b/comb_spec_searcher/combinatorial_class.py @@ -46,6 +46,29 @@ def objects_of_length(self, length): "debug settings and for initial conditions" " for computing the generating function") + def is_epsilon(self): + """Returns True if the generating function equals 1""" + raise NotImplementedError("If you want to use the " + "'count_objects_of_length' function " + "for a proof tree then you must implement " + "'is_epsilon' for your combinatorial class.") + + def is_atom(self): + """Returns True if the generating function equals x""" + raise NotImplementedError("If you want to use the " + "'count_objects_of_length' function " + "for a proof tree then you must implement " + "'is_epsilon', 'is_atom' and 'is_positive' " + "for your combinatorial class.") + + def is_positive(self): + """Returns True if the constant term of the generating function is 0""" + raise NotImplementedError("If you want to use the " + "'count_objects_of_length' function " + "for a proof tree then you must implement " + "'is_epsilon', 'is_atom' and 'is_positive' " + "for your combinatorial class.") + def from_dict(self): """Return combinatorial class from the jsonable object.""" raise NotImplementedError("This function is need to reinstantiate a " diff --git a/comb_spec_searcher/proof_tree.py b/comb_spec_searcher/proof_tree.py index 8b1ad9c03..eb7d488c4 100644 --- a/comb_spec_searcher/proof_tree.py +++ b/comb_spec_searcher/proof_tree.py @@ -16,8 +16,8 @@ from permuta.misc.ordered_set_partitions import partitions_of_n_of_size_k from .tree_searcher import Node as tree_searcher_node -from .utils import (check_equation, check_poly, get_solution, maple_equations, - taylor_expand) +from .utils import (check_equation, check_poly, compositions, get_solution, + maple_equations, taylor_expand) class ProofTreeNode(object): @@ -37,6 +37,8 @@ def __init__(self, label, eqv_path_labels, eqv_path_comb_classes, self.formal_step = formal_step self.sympy_function = None self.terms = [] + self.recurse_node = None + self.genf = None @property def logger_kwargs(self): @@ -284,6 +286,86 @@ def get_equation(self, root_func=None, root_class=None, rhs = sympy.Function("DOITYOURSELF")(sympy.abc.x) return sympy.Eq(lhs, rhs) + def count_objects_of_length(self, n): + ''' + Calculates objects of lenght in each node according to the + recurrence relation implied by the proof tree. Only works + for disjoint union, decomposition, strategy verified and recursion. + + Verified nodes are expected to have a known generating function. + ''' + if n < 0: + return 0 + if len(self.terms) > n: + return self.terms[n] + + ans = 0 + if self.disjoint_union: + ans = sum(child.count_objects_of_length(n) + for child in self.children) + elif self.decomposition: + # Number of children that are just the atom + atoms = 0 + # Indices of children that are positive (do not contain epsilon) + pos_children = set() + children = [] # A list of children that are not atoms + for child in self.children: + if child.eqv_path_comb_classes[-1].is_atom(): + atoms += 1 + else: + if child.eqv_path_comb_classes[-1].is_positive(): + pos_children.add(len(children)) + children.append(child) + + for comp in compositions(n-atoms, len(children)): + # A composition is only valid if all positive children + # get more than 0 atoms. + if any(c == 0 for i, c in enumerate(comp) + if i in pos_children): + continue + tmp = 1 + for i, child in enumerate(children): + tmp *= child.count_objects_of_length(comp[i]) + if tmp == 0: + break + ans += tmp + elif self.strategy_verified: + if self.eqv_path_comb_classes[-1].is_epsilon(): + return 1 if n == 0 else 0 + elif self.eqv_path_comb_classes[-1].is_atom(): + return 1 if n == 1 else 0 + else: + self._ensure_terms(n) + return self.terms[n] + elif self.recursion: + if self.recurse_node: + return self.recurse_node.count_objects_of_length(n) + else: + raise ValueError(("Recursing to a subtree that is not" + " contained in the subtree from the" + " root object that was called on.")) + else: + raise NotImplementedError(("count_objects_of_length() is only " + "defined for disjoint union, " + "cartesian product, recursion " + "and strategy verified.")) + if len(self.terms) != n: + self.terms.extend([0]*(n-len(self.terms))) + self.terms.append(ans) + return ans + + def _ensure_terms(self, n, expand_extra=50): + """ + Ensures that self.terms contains the n-th term. If not it will + use the generating function to compute terms up to n+expand_extra. + """ + if len(self.terms) > n: + return + if self.genf is None: + self.genf = self.eqv_path_comb_classes[-1].get_genf() + coeffs = taylor_expand(self.genf, n=n+expand_extra) + self.terms.extend(coeffs[len(self.terms):]) + @property def eqv_path_objects(self): """This is for reverse compatability.""" @@ -304,6 +386,7 @@ def __init__(self, root): raise TypeError("Root must be a ProofTreeNode.") self.root = root self._of_length_cache = {} + self._fixed_recursion = False @property def logger_kwargs(self): @@ -697,6 +780,24 @@ def from_comb_spec_searcher_node(cls, root, css, in_label=None): extra={"processname": "css_to_proof_tree"}) raise NotImplementedError("Only handle cartesian and disjoint") + def _recursion_setup(self): + label_to_node = dict() + + for node in self.nodes(): + if not node.recursion: + label_to_node[node.label] = node + + for node in self.nodes(): + if node.recursion: + node.recurse_node = label_to_node[node.label] + + self._fixed_recursion = True + + def count_objects_of_length(self, n): + if not self._fixed_recursion: + self._recursion_setup() + return self.root.count_objects_of_length(n) + def __eq__(self, other): return all(node1 == node2 for node1, node2 in zip(self.nodes(), other.nodes())) diff --git a/comb_spec_searcher/utils.py b/comb_spec_searcher/utils.py index 9edca0746..5c408a4e7 100644 --- a/comb_spec_searcher/utils.py +++ b/comb_spec_searcher/utils.py @@ -128,3 +128,32 @@ def maple_equations(root_func, root_class, eqs): s += "count := {}:".format([len(list(root_class.objects_of_length(i))) for i in range(6)]) return s + + +def compositions(n, k): + # Credit to: + # https://pythonhosted.org/combalg-py/_modules/combalg/combalg.html + if n < 0: + raise ValueError("Can't make compositions of negative numbers") + if k < 0: + raise ValueError("Can't make compositions into a negative " + " number of parts") + if k > n or (n > k and k == 0): + raise ValueError("Can't make compositions of {} into " + "{} parts".format(n, k)) + if k == 0: + return + t = n + h = 0 + a = [0]*k + a[0] = n + yield list(a) + while a[k-1] != n: + if t != 1: + h = 0 + t = a[h] + a[h] = 0 + a[0] = t-1 + a[h+1] += 1 + h += 1 + yield list(a)