Skip to content

Commit

Permalink
Merge 21929b9 into 150be0a
Browse files Browse the repository at this point in the history
  • Loading branch information
christianbean committed Dec 13, 2019
2 parents 150be0a + 21929b9 commit 7e4e686
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 42 deletions.
12 changes: 10 additions & 2 deletions CHANGELOG.md
Expand Up @@ -5,12 +5,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- `ProofTree.generate_objects_of_length()` implements an algorithm for
generating the objects of a given length by utilising the structure implied
by a proof tree.
- `ProofTreeNode.is_atom()` and `ProofTreeNode.is_epsilon()` methods for
checking if a node represents an atom or epsilon.
### Changed
- Use polynomial algorithm for generating terms in random sampling code.

## [0.2.2] - 2019-09-06
### Added
- `ProofTree.count_objects_of_length()` implements the recurrence relation
implied by the proof tree as long as the strategies used are only disjoint unions,
decompositions, verification or recursion.
implied by the proof tree as long as the strategies used are only disjoint
unions, decompositions, verification or recursion.

### Removed
- Remove the dependency on `permuta`.
Expand Down
126 changes: 87 additions & 39 deletions comb_spec_searcher/proof_tree.py
Expand Up @@ -8,6 +8,7 @@
import sys
import warnings
from functools import reduce
from itertools import product
from operator import add, mul

import sympy
Expand Down Expand Up @@ -39,6 +40,7 @@ def __init__(self, label, eqv_path_labels, eqv_path_comb_classes,
self.terms = []
self.recurse_node = None
self.genf = None
self.objects_of_length = dict()

@property
def logger_kwargs(self):
Expand Down Expand Up @@ -100,26 +102,8 @@ def _error_string(self, parent, children, strat_type, formal_step,
error += "They produced {} many things\n\n".format(children_total)
return error

def random_sample(self, length, tree=None):
def random_sample(self, length):
"""Return a random object of the given length."""
def partitions(n, children_totals):
if n == 0 and not children_totals:
yield []
return
if len(children_totals) == 0 or n < 0:
return
start = children_totals[0]
if len(children_totals) == 1:
if start[n] != 0:
yield [n]
return
for i in range(n + 1):
if start[i] == 0:
continue
else:
for part in partitions(n - i, children_totals[1:]):
yield [i] + part

if self.disjoint_union:
total = self.terms[length]
if total == 0:
Expand All @@ -134,22 +118,27 @@ def partitions(n, children_totals):
for child, child_total in children_totals:
sofar += child_total
if choice <= sofar:
return child.random_sample(length, tree)
return child.random_sample(length)
raise ValueError("You shouldn't be able to get here!")
elif self.decomposition:
non_atom_children = [child for child in self.children
if not child.is_atom()]
number_of_atoms = len(self.children) - len(non_atom_children)
total = self.terms[length]
choice = random.randint(1, total)
children_totals = [child.terms for child in self.children]
sofar = 0
for part in partitions(length, children_totals):
for comp in compositions(length - number_of_atoms,
len(non_atom_children)):
subtotal = 1
for i, terms in zip(part, children_totals):
subtotal *= terms[i]
for i, child in zip(comp, non_atom_children):
subtotal *= child.terms[i]
sofar += subtotal
if choice <= sofar:
sub_objs = [(child.random_sample(i, tree),
comp = list(reversed(comp))
sub_objs = [(child.random_sample((1 if child.is_atom() else
comp.pop())),
child.eqv_path_comb_classes[0])
for i, child in zip(part, self.children)]
for child in self.children]
comb_class = self.eqv_path_comb_classes[-1]
return comb_class.from_parts(*sub_objs,
formal_step=self.formal_step)
Expand All @@ -158,13 +147,62 @@ def partitions(n, children_totals):
return self.eqv_path_comb_classes[-1].random_sample(length)
else:
if self.recursion:
for node in tree.nodes():
if node.label == self.label and not node.recursion:
return node.random_sample(length, tree)
return self.recurse_node.random_sample(length)
raise NotImplementedError(("Random sampler only implemented for "
"disjoint union and cartesian "
"product."))

def generate_objects_of_length(self, n):
"""Yield objects of given length."""
if n in self.objects_of_length:
yield from self.objects_of_length[n]
return
else:
res = []
# TODO: handle equivalence path nodes (somewhat assume length 1)
if self.disjoint_union:
for child in self.children:
for path in child.generate_objects_of_length(n):
yield path
res.append(path)
elif self.decomposition:
comb_class = self.eqv_path_comb_classes[-1]
child_comb_classes = [child.eqv_path_comb_classes[0]
for child in self.children]
number_atoms = sum(1 for child in child_comb_classes
if child.is_atom())
for comp in compositions(n - number_atoms,
len(self.children) - number_atoms):
i, actual_comp = 0, []
for child in child_comb_classes:
if child.is_atom():
actual_comp.append(1)
else:
actual_comp.append(comp[i])
i += 1
for child_objs in product(*[child.generate_objects_of_length(i)
for child, i in zip(self.children,
actual_comp)]):
parts = [(x, y)
for x, y in zip(child_objs, child_comb_classes)]
path = comb_class.from_parts(*parts)
yield path
res.append(path)
elif self.strategy_verified:
for path in self.eqv_path_comb_classes[-1].objects_of_length(n):
yield path
res.append(path)
else:
if self.recursion:
for path in self.recurse_node.generate_objects_of_length(n):
yield path
res.append(path)
else:
raise NotImplementedError(("Object generator only implemented "
"for disjoint union and cartesian "
"product."))
self.objects_of_length[n] = res

def sanity_check(self, length, of_length=None):
if of_length is None:
raise ValueError("of_length is undefined.")
Expand Down Expand Up @@ -303,7 +341,7 @@ def count_objects_of_length(self, n):
pos_children = set()
children = [] # A list of children that are not atoms
for child in self.children:
if child.eqv_path_comb_classes[-1].is_atom():
if child.is_atom():
atoms += 1
else:
if child.eqv_path_comb_classes[-1].is_positive():
Expand All @@ -323,9 +361,9 @@ def count_objects_of_length(self, n):
break
ans += tmp
elif self.strategy_verified:
if self.eqv_path_comb_classes[-1].is_epsilon():
if self.is_epsilon():
return 1 if n == 0 else 0
elif self.eqv_path_comb_classes[-1].is_atom():
elif self.is_atom():
return 1 if n == 1 else 0
else:
self._ensure_terms(n)
Expand Down Expand Up @@ -359,6 +397,14 @@ def _ensure_terms(self, n, expand_extra=50):
coeffs = taylor_expand(self.genf, n=n+expand_extra)
self.terms.extend(coeffs[len(self.terms):])

def is_atom(self):
return any(comb_class.is_atom()
for comb_class in self.eqv_path_comb_classes)

def is_epsilon(self):
return any(comb_class.is_epsilon()
for comb_class in self.eqv_path_comb_classes)

@property
def eqv_path_objects(self):
"""This is for reverse compatability."""
Expand Down Expand Up @@ -590,17 +636,15 @@ def get_min_poly(self, **kwargs):
raise RuntimeError(("Incorrect minimum polynomial\n" +
str(basis)))

def random_sample(self, length=100, solved=False):
def random_sample(self, length=100):
if any(len(node.terms) < length + 1 for node in self.nodes()):
logger.info(("Computing terms"))
funcs = self.get_genf(only_root=False)
self._recursion_setup()
for node in self.nodes():
if len(node.terms) < length + 1:
logger.info(("Taylor expanding function {} to length {}."
"".format(node.get_function(), length)))
node.terms = taylor_expand(funcs[node.label], length)
node.terms = [node.count_objects_of_length(i)
for i in range(length + 1)]
logger.info("Walking through tree")
return self.root.random_sample(length, self)
return self.root.random_sample(length)

def nodes(self, root=None):
if root is None:
Expand Down Expand Up @@ -791,6 +835,10 @@ def count_objects_of_length(self, n):
self._recursion_setup()
return self.root.count_objects_of_length(n)

def generate_objects_of_length(self, n):
self._recursion_setup()
yield from self.root.generate_objects_of_length(n)

def __eq__(self, other):
return all(node1 == node2
for node1, node2 in zip(self.nodes(), other.nodes()))
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -53,6 +53,6 @@ def read(fname):
'pytest-pep8==1.0.6',
'pytest-repeat==0.8.0',
'docutils==0.15.2',
'Pygments==2.5.0'
'Pygments==2.5.2'
]
)

0 comments on commit 7e4e686

Please sign in to comment.