# Rooted trees

see also https://mathworld.wolfram.com/RootedTree.html



* This paper may also be useful: https://projecteuclid.org/download/pdf_1/euclid.acta/1485892234



In [24]:
from typing import Any, List, Tuple
from functools import lru_cache
from collections import defaultdict
from functools import reduce
from operator import mul


def strings_(n: int, d: int):
    '''Generate the non-equivalent strings of length n that use exactly d letters.

    Two strings over alphabets of the same size are equivalent when some
    bijection between the alphabets maps one string onto the other.  The
    alphabet is 1, 2, 3, ... and every string is yielded as a list of ints.

    Valid input is n >= 0 and d >= 0; anything else raises ValueError once
    the generator is first advanced.

    This is the problem of partitioning an n-element set into d classes.  The
    recursion below is the Stirling-number-of-the-second-kind recurrence
    S(n, d) = S(n-1, d-1) + d * S(n-1, d); summing over d gives Bell numbers.
    '''
    if n < 0 or d < 0:
        raise ValueError
    if n == 0:
        # Only the empty string has length 0, and it uses zero letters.
        if d == 0:
            yield []
        return
    if d == 0:
        # A non-empty string cannot be written with no letters.
        return
    # Case 1: the last position introduces the new letter d.
    for shorter in strings_(n - 1, d - 1):
        yield shorter + [d]
    # Case 2: all d letters already occur; the last position repeats one.
    for shorter in strings_(n - 1, d):
        for letter in range(1, d + 1):
            yield shorter + [letter]

def strings(n):
    '''Generate all non-equivalent strings of length n.

    Chains strings_(n, d) for d = 1, 2, ..., n, i.e. by increasing number of
    distinct letters.  Nothing is yielded for n <= 0 because d starts at 1.
    '''
    for n_letters in range(1, n + 1):
        for word in strings_(n, n_letters):
            yield word

def  _pmult(n, i_max):
    # yields partitions in the multiplicity format 
    # 1*k1 + 2*k2 + 3*k3 + .. + i * k_i = n 
    # where k_i denotes how many times "i" appears in the partition of n
    # the function is recursion over i_max
    # in this form that partition function of partitions is the Euler function
    # prod 1/(1 - q^i)
    if (i_max==0):
        if (n==0):
            yield []
        else:
            # nothing
            pass
    else:
        for k in range(n//i_max + 1):
            for prev in _pmult(n - k*i_max, i_max - 1):
                yield prev + [k]

@lru_cache(maxsize=None)
def pmult(n, i_max):
    '''Memoised, materialised version of _pmult(n, i_max).

    The returned list is the cached object itself, so callers must treat it
    as read-only.
    '''
    partitions = list(_pmult(n, i_max))
    return partitions
                

@lru_cache(maxsize=None)
def sym(n, k):
    '''Return the size of the k-th symmetric power of an n-element set.

    This is n*(n+1)*...*(n+k-1) / k!, the number of multisets of size k
    drawn from n objects.  The product is built one factor at a time; each
    intermediate value is itself such a count, so the floor division is exact.
    '''
    count = 1
    for step in range(1, k + 1):
        count = count * (n + step - 1) // step
    return count


def symlist(data: List[Any], k: int):
    '''Yield the k-th symmetric power of data as index-sorted lists of length k.

    Every result lists k items of data (repetition allowed) in non-decreasing
    order of their position in data, for example
    symlist(['a','b','c'], 2) ->
        ['a','a'], ['a','b'], ['b','b'], ['a','c'], ['b','c'], ['c','c']
    '''
    if k == 0:
        yield []
        return
    for position, last in enumerate(data):
        # The head may only use items up to and including the last one chosen.
        allowed = data[:position + 1]
        for head in symlist(allowed, k - 1):
            yield [*head, last]
            
            
    

@lru_cache(maxsize=None)
def ptreemult(n):
    '''Count the trees on n nodes without enumerating them.

    For n > 1 the n-1 non-root nodes are split among the root's subtrees.
    For every partition of n-1 in multiplicity form (see pmult), the k
    subtrees of size i are a multiset drawn from the ptreemult(i) trees of
    that size, which gives sym(ptreemult(i), k) choices; the choices for
    different sizes multiply and the partitions add up.
    '''
    if n == 1:
        return 1
    total = 0
    for multiplicities in pmult(n - 1, n - 1):
        ways = 1
        for size, count in enumerate(multiplicities, start=1):
            ways *= sym(ptreemult(size), count)
        total += ways
    return total
        

def product(values, offset):
    '''Yield the cartesian product of the iterables values[offset:], as lists.

    offset exists for the recursion only: it lets the function walk along
    values without building sliced copies.  Pass 0 to get the full product.
    (With a linked list one would recurse on head :: tail instead.)

    Raises ValueError, on first advance of the generator, when offset is
    larger than len(values).
    '''
    remaining = len(values) - offset
    if remaining < 0:
        raise ValueError
    if remaining == 0:
        # Empty product: exactly one (empty) combination.
        yield []
        return
    for first in values[offset]:
        for tail in product(values, offset + 1):
            yield [first, *tail]

def flatten(z: List[List[Any]]):
    '''Concatenate the inner lists of z into one new list (one level only).'''
    flat = []
    for inner in z:
        flat.extend(inner)
    return flat
        
def _trees(n):
    if (n==0):
        pass
    elif (n==1):
        yield [0]
    else:
        for p in pmult(n-1,n-1):
            children_opts = []
            for i, k in enumerate(p):
                slist = list(symlist(list(trees(i+1)), k))
                if len(slist) > 0:
                    children_opts.append(slist)
            for children in product(children_opts, 0):
                size = len(flatten(children))
                yield flatten(flatten(children)) + [-size]

@lru_cache(maxsize=None)
def trees(n):
    '''Memoised list of all encodings produced by _trees(n).

    The cached list object is returned directly; treat it as read-only.
    '''
    encodings = list(_trees(n))
    return encodings


def postfix_tree_expr(n, d):
    '''Yield every (topology, colouring) combination for trees on n nodes.

    The topologies come from trees(n) and the colourings of the n nodes with
    exactly d colours come from strings_(n, d).  Each result is a list of n
    pairs (-arity, colour) in postfix order.

    The colouring string is read in DFS order, which is the reverse of the
    postfix list, so it is reversed before being paired with the topology.
    '''
    for topology in trees(n):
        for colouring in strings_(n, d):
            yield list(zip(topology, colouring[::-1]))
            
            
import hashlib


def hash_fn(x):
    '''Return the low 64 bits of the MD5 digest of the bytes x, as an int.'''
    # The last 8 bytes of the big-endian digest are exactly the digest value
    # modulo 2**64.
    return int.from_bytes(hashlib.md5(x).digest()[-8:], byteorder='big')


def tree_hash(e: List[Tuple[int, int]]):
    '''Hash a coloured tree so that reordering or duplicating children does not matter.

    e: postfix encoding of a coloured tree, a list of (-arity, colour) pairs
       where arity is the number of children of the node and colour is a
       positive integer (it must fit in 8 unsigned bytes).

    For every node the hashes of its children are collected into a set (so
    repeated identical children collapse to one), the set is summed modulo
    2**64 (a symmetric function, hence permutation invariant), and the sum
    is hashed together with the node colour.  Any other symmetric function,
    e.g. a product, would serve equally well.

    Returns the hash of the root, an int below 2**64.

    Raises ValueError when e is not a well-formed postfix expression: a node
    has a positive arity entry, asks for more children than are available,
    or the expression does not reduce to exactly one tree.
    '''
    stack = []
    for narity, node in e:
        if narity > 0:
            raise ValueError(f"Bad postfix expression, arity must be encoded as a non-positive integer, got: {narity}")
        # Arities are stored negated, so the number of children is -narity.
        n_children = -narity
        if len(stack) < n_children:
            raise ValueError(f"Bad postfix expression, not enough children to pop, stack: {stack}")
        child_hashes = {stack.pop() for _ in range(n_children)}
        previous = hash_fn((sum(child_hashes) % 2**64).to_bytes(8, byteorder='big'))
        result = hash_fn(previous.to_bytes(8, byteorder='big') + node.to_bytes(8, byteorder='big'))
        stack.append(result)
    if len(stack) != 1:
        raise ValueError(f"Bad postfix expression, finished with stack: {stack} for expr: {e}")
    return stack.pop()
    
    
def postfix_tree_canonical(n, d):
    '''Yield one postfix expression per tree_hash class, for sizes 1..n.

    Expressions are taken from postfix_tree_expr(cnt, d) for cnt = 1, ..., n
    in generation order.  An expression is yielded only if its tree_hash has
    not been seen before, so the first representative of every hash class is
    kept and later isomorphic trees are dropped.  The set of seen hashes is
    shared across all sizes.
    '''
    seen_hashes = set()
    for cnt in range(1, n + 1):
        for e in postfix_tree_expr(cnt, d):
            ehash = tree_hash(e)
            if ehash not in seen_hashes:
                seen_hashes.add(ehash)
                # Yield the bare expression: consumers iterate it as a list
                # of (-arity, colour) pairs.
                yield e
                
                



def bool_denote(e: List[Tuple[int, int]], arg: List[bool]):
    '''Evaluate a postfix tree as a Boolean expression with a stack machine.

    e:   postfix tree, a list of (-arity, node name) pairs where node names
         are the integers 1, 2, 3, ...
    arg: Boolean values of the variables; node name i reads arg[i-1].

    A node evaluates to  (not c_1) or (not c_2) or ... or arg[node-1],
    where c_1, c_2, ... are the values of its children.

    Returns the value of the root node.

    Raises ValueError when e is not a well-formed postfix expression: a node
    has a positive arity entry, asks for more children than are available,
    or the expression does not reduce to exactly one value.
    '''
    stack = []
    for narity, node in e:
        if narity > 0:
            raise ValueError(f"Bad postfix expression, arity must be encoded as a non-positive integer, got: {narity}")
        # Arities are stored negated, so the number of children is -narity.
        n_children = -narity
        if len(stack) < n_children:
            raise ValueError(f"Bad postfix expression, not enough children to pop, stack: {stack}")
        # Pop every child first: any() short-circuits, and a lazily consumed
        # generator would leave unread children on the stack.
        children = [stack.pop() for _ in range(n_children)]
        result = any(not child for child in children) or arg[node-1]
        stack.append(result)
    if len(stack) != 1:
        print("bad expression", e)
        raise ValueError(f"Bad postfix expression, finished with stack: {stack} for expr: {e}")
    return stack.pop()
    
def valid(e):
    '''Decide by brute force whether the expression e is a tautology.

    In propositional calculus, proving a theorem amounts to checking that
    the expression evaluates to True under every assignment of its
    variables.  The number of variables is the largest node name in e.
    '''
    n_args = max(node for _, node in e)
    domain = [[False, True] for _ in range(n_args)]
    for assignment in product(domain, 0):
        if not bool_denote(e, assignment):
            return False
    return True



### rooted trees https://oeis.org/A000081

In [25]:
# Count the encodings returned by trees(i) for i = 0..13.
list(len(trees(i)) for i in range(14))       # number of unlabelled rooted trees with i nodes

[0, 1, 1, 2, 4, 9, 20, 48, 115, 286, 719, 1842, 4766, 12486]

In [26]:
trees(4)   # notation: an entry -k (k >= 0) is a node with k children: pop k subtrees from the stack, combine them, push the result back

[[0, 0, 0, -3], [0, 0, -1, -2], [0, 0, -2, -1], [0, -1, -1, -1]]

### Postfix trees
#### A tree with n nodes in postfix form is a list of n pairs
#### Each pair is (-node.degree, node.variable_id) 

#### The ordering of the pairs is the reverse of the DFS traversal (i.e. the last pair is the root node)

#### Note: if we join the pairs into one contiguous list of length 2n, the notation is equivalent to our postfix s-expression notation, where an integer -n <= 0 denotes n pop() operations and an integer > 0 is the id of an atom to push

In [27]:
# Print every topology/colouring combination for 3 nodes and exactly 2 colours.
for e in postfix_tree_expr(3,2):
    print(e)

[(0, 2), (0, 1), (-2, 1)]
[(0, 1), (0, 2), (-2, 1)]
[(0, 2), (0, 2), (-2, 1)]
[(0, 2), (-1, 1), (-1, 1)]
[(0, 1), (-1, 2), (-1, 1)]
[(0, 2), (-1, 2), (-1, 1)]


### tree_hash function:

In [28]:
# Hash every single-colour expression with 1, 2 and 3 nodes.
for cnt in range(1, 4):
    for e in postfix_tree_expr(cnt,1):
        print(e, tree_hash(e))  # trees with repeated or permuted children have the same hash

[(0, 1)] 4896303482726436129
[(0, 1), (-1, 1)] 13721624596682143247
[(0, 1), (0, 1), (-2, 1)] 13721624596682143247
[(0, 1), (-1, 1), (-1, 1)] 14058027741666673059


## Cumulative number of non-equivalent rooted colored trees of size <= n

- Need to compare with  https://oeis.org/A304486


In [29]:
# cuml[n]: number of expressions kept by postfix_tree_canonical for sizes <= n,
# summed over d = 1..n colours (cuml[0] is 0 because the sum is empty).
cuml = list(sum(len(list(postfix_tree_canonical(n,d))) for d in range(1, n+1)) for n in range(0, 9)) # cumulative

In [30]:
# this is the result of running tree_hash with sum(list(...)) over the children hashes instead of sum(set(...))
# [1, 2, 9, 49, 347, 2795, 26692, 280591]     
   

In [31]:
# Successive differences of the cumulative counts.
[cuml[i] - cuml[i-1] for i in range(1, len(cuml))]   # number of non-isomorphic (by tree_hash) colored rooted trees of size 1, 2, 3, ...

[1, 2, 7, 38, 266, 2148, 20480, 214721]

# Matula Goebel index

see also Abe https://www.sciencedirect.com/science/article/pii/0893965994900531?via%3Dihub
and http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.90.6380&rep=rep1&type=pdf

In [12]:
@lru_cache(None)
def primes(size: int):
    """
        returns a list of the first `size` primes, in increasing order

        Uses trial division by the primes found so far.  The result is
        cached per `size`, and the cached list object itself is returned,
        so callers must not mutate it.  size <= 0 gives an empty list.
    """
    ps = []
    x = 2
    while len(ps) < size:
        is_prime = True
        for p in ps:
            # ps is ascending, so once p*p exceeds x no later prime can be
            # a divisor either: stop instead of scanning the whole list.
            if p * p > x:
                break
            if x % p == 0:
                is_prime = False
                break
        if is_prime:
            ps.append(x)
        x += 1
    return ps

def prime(index: int):
    """
        returns the index-th prime, counting from prime(1) == 2

        Raises ValueError for index < 1; without this check a zero or
        negative index would silently pick a prime from the end of the table.
    """
    if index < 1:
        raise ValueError(f"prime index must be >= 1, got {index}")
    # Always request at least 10000 primes so that small indices share one
    # cached table.
    return primes(max(10000, index))[index-1]

In [13]:
def mg_index(e):
    '''Matula-Goebel index of a rooted tree given in postfix form.

    e: list of (-arity, node name) pairs in postfix order.  Node names are
       ignored; only the shape of the tree enters the index.

    A leaf has index 1; a node whose children have indices i_1, ..., i_k has
    index prime(i_1) * ... * prime(i_k).

    Raises ValueError when e is not a well-formed postfix expression: a node
    has a positive arity entry, asks for more children than are available,
    or the expression does not reduce to exactly one tree.
    '''
    stack = []
    for narity, node in e:
        if narity > 0:
            raise ValueError(f"Bad postfix expression, arity must be encoded as a non-positive integer, got: {narity}")
        # Arities are stored negated, so the number of children is -narity.
        n_children = -narity
        if len(stack) < n_children:
            raise ValueError(f"Bad postfix expression, not enough children to pop, stack: {stack}")
        value = 1
        for _ in range(n_children):
            value *= prime(stack.pop())
        stack.append(value)
    if len(stack) != 1:
        print("bad expression", e)
        raise ValueError(f"Bad postfix expression, finished with stack: {stack} for expr: {e}")
    return stack.pop()
    
    

In [62]:
from collections import defaultdict
def postfix_tree_canonical_(n, d):
    '''Group the expressions of postfix_tree_canonical(n, d) by Matula-Goebel index.

    Returns a defaultdict(list) that maps mg_index(e) to the expressions e
    with that index, in generation order.
    '''
    by_index = defaultdict(list)
    for expr in postfix_tree_canonical(n, d):
        index = mg_index(expr)
        by_index[index].append(expr)
    return by_index

In [64]:
# Canonical 2-colour expressions with up to 5 nodes, grouped by Matula-Goebel index.
postfix_tree_canonical_(5,2)

defaultdict(list,
            {2: [[(0, 2), (-1, 1)]],
             4: [[(0, 2), (0, 1), (-2, 1)], [(0, 2), (0, 2), (-2, 1)]],
             3: [[(0, 2), (-1, 1), (-1, 1)],
              [(0, 1), (-1, 2), (-1, 1)],
              [(0, 2), (-1, 2), (-1, 1)]],
             8: [[(0, 2), (0, 1), (0, 1), (-3, 1)],
              [(0, 2), (0, 2), (0, 1), (-3, 1)],
              [(0, 2), (0, 2), (0, 2), (-3, 1)]],
             6: [[(0, 2), (0, 1), (-1, 1), (-2, 1)],
              [(0, 1), (0, 2), (-1, 1), (-2, 1)],
              [(0, 2), (0, 2), (-1, 1), (-2, 1)],
              [(0, 1), (0, 1), (-1, 2), (-2, 1)],
              [(0, 2), (0, 1), (-1, 2), (-2, 1)],
              [(0, 1), (0, 2), (-1, 2), (-2, 1)],
              [(0, 2), (0, 2), (-1, 2), (-2, 1)]],
             7: [[(0, 2), (0, 1), (-2, 1), (-1, 1)],
              [(0, 2), (0, 2), (-2, 1), (-1, 1)],
              [(0, 1), (0, 1), (-2, 2), (-1, 1)],
              [(0, 2), (0, 1), (-2, 2), (-1, 1)],
              [(0, 2), (0, 2), (-2,