In [8]:
import re
import torch
from torch_geometric.data import Data, InMemoryDataset, DataLoader

In [26]:
class Tree:
    def __init__(self, root, parent=None):
        self.root = root
        self.parents = [parent] if parent else []
        self.subtrees = []
        self.subtree_repr = None
        
    def add_subtrees(self, *subtrees):
        self.subtrees.extend(subtrees)
    
    def add_subtree(self, subtree):
        self.subtrees.append(subtree)
        
    def view_subtree(self):
        return self.subtree_repr
    
    def __len__(self):
        n_nodes = 0
        q = []
        q.append(self)
        seen = [self]
        while q:
            node = q.pop(0)
            n_nodes += 1
            for child in node.subtrees:
                if child in seen:
                    continue
                else:
                    q.append(child)
                    seen.append(child)
        return n_nodes
        
    def __str__(self):
        return f'Tree(root={self.root}, parents={self.parents}, size={self.__len__()})'

In [10]:
class Node:
    def __init__(self, label, index=None, value=None):
        self.label = label
        self.index = index
        self.value = value
        
    def __str__(self):
        return f'({self.label}, index={self.index})'
    
    def __repr__(self):
        return f'({self.label}, index={self.index})'

In [11]:
def make_tree(proof):
    # Proof tree format in regex below
    m = re.search('(nodes {.*})+.*theorem_in_database.*', proof)
    proof = m.group(1)
    
    tree = {}
    sections = []
    bracket_depth = 0
    start_of_next = 0
    is_naming = True
    name = []
    
    for i, char in enumerate(proof):
        # First find the name of the 
        if is_naming:
            if char == ' ' and len(name) == 0:
                continue
            elif char.isalpha():
                name.append(char)
            elif char == ' ':
                is_naming = False
                start_of_next = i
        else:
            if char == '{':
                bracket_depth += 1
            elif char == '}':
                bracket_depth -= 1
                if bracket_depth == 0:
                    sections.append((''.join(name), proof[start_of_next:i+1]))
                    name = []
                    is_naming = True
        
    for _, section in sections:
        print(section)
        print()
        
    fingerprints_seen = set()
    tactics_used = list()
    initial_fingerprint = None
    
    for _, section in sections:
        theorem_fingerprint = None
        found_tactic_before = False
        children = []
        split_section = section.split()
        print(split_section)

        for i, token in enumerate(split_section):
            if token == 'tactic:':
                tactic = split_section[i+1]
                if found_tactic_before:
                    print('Wait a minute... you\'d already found the tactic!')
                found_tactic_before = False
                tactics_used.append(tactic)

                
            # ======== Good code =========
            if token == 'fingerprint:':
                fingerprints_seen.add(split_section[i+1])
                
                if initial_fingerprint is None:
                    initial_fingerprint = split_section[i+1]
                    
                else:
                    if theorem_fingerprint is not None:
                        children.append(split_section[i+1])
                        
                    else:
                        theorem_fingerprint = split_section[i+1]

        tree[theorem_fingerprint] = children
                
                
    return tree, fingerprints_seen, tactics_used

In [12]:
def process_theorem(theorem):
    x = theorem
    x = x.replace(r"\'", "'")
    x = x.replace('(', ' ( ')
    x = x.replace(')', ' ) ')
    x = x.split()
    return x

In [24]:
# TODO: Assuming that never see '(' twice in a row. Is this valid?
from collections import deque

distinct_features = set()

def bfs(tree, store_index=False):
    index = 0
    queue = deque()
    visit_order = []
    queue.append(tree)
    
    if store_index:
        tree.root.index = index
        index += 1
        
    while queue:
        x = queue.popleft()
        visit_order.append(x.root)
        
        for subtree in x.subtrees:
            queue.append(subtree)
            
            if store_index:
                subtree.root.index = index
                index += 1
                
    return tuple(visit_order)


# TODO: Change special expressions to reduced forms
def thm_to_tree(theorem):
    subexpressions = dict()
    tree = Tree(root='', parent=None)
    current_tree = tree
    i_sym = 0
    while i_sym < len(theorem):
        sym = theorem[i_sym]
        if sym == '(':
            new_subtree = Tree(root=Node(theorem[i_sym + 1]), parent=current_tree)
            current_tree.add_subtree(new_subtree)
            current_tree = new_subtree
            #i_sym += 1
        elif sym == ')':
            current_tree.subtree_repr = bfs(current_tree)
            current_tree = current_tree.parents[0]
        else:
            distinct_features.add(sym)
            current_tree.add_subtree(Tree(root=Node(theorem[i_sym]), parent=current_tree))

        i_sym += 1
    
    final_tree = tree.subtrees[0]
    final_tree.parents = None
    
    bfs(final_tree, store_index=True)
    return final_tree
    

def merge_subexpressions(tree):
    subexpressions = dict()
    stack = []
    stack.append(tree)
    
    while stack:
        t = stack.pop()
        if t.subtree_repr in subexpressions:
            parent = t.parents[0]
            for i, subtree in enumerate(parent.subtrees):
                if subtree.subtree_repr == t.subtree_repr:
                    parent.subtrees[i] = subexpressions[t.subtree_repr]
                    if parent not in subexpressions[t.subtree_repr].parents:
                        subexpressions[t.subtree_repr].parents.append(parent)
        else:
            for subtree in t.subtrees:
                stack.append(subtree)
    return tree
    

In [14]:
DISTINCT_FEATURES = ['fun', 'B', 'a', 'A']

def graph_to_data(tree, distinct_features):
    edges = []
    features = []
    
    stack = []
    stack.append(tree)
    processed_subtrees = []
    while stack:
        x = stack.pop()
        features.append(x.root.label)
        
        if x.parents:
            for parent in x.parents:
                edges.append([parent.root.index, x.root.index])
        
        for subtree in x.subtrees[::-1]:
            if subtree in processed_subtrees:
                continue
                
            stack.append(subtree)
            edges.append([x.root.index, subtree.root.index])
            

    features = torch.tensor([[distinct_features.index(x)] for x in features])

    edges = torch.tensor(edges)
    edges = edges.permute(1,0)
    
    #datum = Data(x=features, edge_index=edges)
    
    return features, edges

In [15]:
def get_fingerprints(section):
    # Go through and find fingerprint of current conclusion, and of subgoals
    theorem_fingerprint = None
    getting_subgoals = False
    child_fingerprints = []
    split_section = section.split()
    for idx, token in enumerate(split_section):
        if token == 'subgoals':
            getting_subgoals = True
        if token == 'fingerprint:':
            if theorem_fingerprint is None:
                theorem_fingerprint = split_section[idx + 1]
            elif getting_subgoals is True:
                child_fingerprints.append(split_section[idx + 1])
    return theorem_fingerprint, child_fingerprints


def get_theorems(proof):
    # Find all theorems within
    tree = None
    theorems = []
    index = dict()
    x = re.split('theorem_in_database', proof)
    proof = x[0]
    m = re.split('nodes', proof)
    for section in m:
        if section:
            theorem = re.search('conclusion: "([^"]*)"', section)
            theorems.append(theorem.group(1))
            theorem = theorem.group(1)
            fingerprint, children = get_fingerprints(section)    # fingerprint matches theorem, can be used to index into tree
            
            if tree is None:
                # Create root node for proof tree, i.e. final conclusion, with subgoals
                tree = Tree(root=Node(
                                    label=fingerprint,
                                    value=theorem
                            ))
                index[fingerprint] = tree.root
                
                for child in children:
                    subtree = Tree(root=Node(label=child),
                                  parent=tree)
                    tree.add_subtree(subtree)
                    index[child] = subtree
                
            else:
                # Add subtree with subgoals
                subtree = index[fingerprint]
                subtree.root.value = theorem    # Only fingerprint is stored initially. We now have theorem, so update.
                
                for child in children:
                    subsubtree = Tree(root=Node(label=child),
                                     parent=subtree)
                    subtree.add_subtree(subsubtree)
                    index[child] = subsubtree
                
                
    return theorems, tree


def get_tree(proof):
    # Get root (first conclusion)
    # Recursively build tree, getting subgoals of root
    return NotImplemented

In [9]:
with open('../deephol-data/deepmath/deephol/proofs/human/vocab_goal_ls.txt', 'r') as f:
    vocab_goal_ls = set(x.strip() for x in f.readlines()[1::])
    
with open('../deephol-data/deepmath/deephol/proofs/human/vocab_ls.txt', 'r') as f:
    vocab_ls = set(x.strip() for x in f.readlines()[1::])
    
with open('../deephol-data/deepmath/deephol/proofs/human/vocab_thms_ls.txt', 'r') as f:
    vocab_thms_ls = set(x.strip() for x in f.readlines()[1::])

In [53]:
with open('../deephol-data/deepmath/deephol/proofs/human/train/prooflogs-00000-of-00600.pbtxt', 'r') as f:
    i = 0
    for line in f:
        if i == 40:
            print(f'******** {i} ********')
            theorems, tree = get_theorems(line)
            print(len(tree))
            break
        i += 1

******** 40 ********
9


In [16]:
thm = '(a (a (c (fun (bool) (fun (bool) (bool))) =) (a (a (c (fun (real) (fun (real) (bool))) real_le) (a (c (fun (cart (real) (1)) (real)) drop) (v (cart (real) (1)) b))) (a (c (fun (cart (real) (1)) (real)) drop) (v (cart (real) (1)) d)))) (c (bool) T))'

def make_data():
    datapoints = []
    for i in range(600):
        if i % 10 == 0:
            print(i)
        label = str(i)
        if i // 10 == 0:
            label = '0' + label
        if i // 100 == 0:
            label = '0' + label
        with open(f'../deephol-data/deepmath/deephol/proofs/human/train/prooflogs-00{label}-of-00600.pbtxt', 'r') as f:
            for line in f:
                theorems, tree = get_theorems(line)
                datapoints.append((tree.root.value, len(tree)))
    return datapoints

In [17]:
datapoints = make_data()
#print(len(datapoints))
#print(datapoints[40])

0
10
20
30
40
50
60
70
80
90
100
110
120
130
140
150
160
170
180
190
200
210
220
230
240
250
260
270
280
290
300
310
320
330
340
350
360
370
380
390
400
410
420
430
440
450
460
470
480
490
500
510
520
530
540
550
560
570
580
590


In [29]:
test_thm = '(fun (a A B) (a A (a A B)))'
print(test_thm)
thm = process_theorem(test_thm)
print(thm)
thm_tree = thm_to_tree(thm)
print(thm_tree)
merged_thm_tree = merge_subexpressions(thm_tree)
print(merged_thm_tree)
print

(fun (a A B) (a A (a A B)))
['(', 'fun', '(', 'a', 'A', 'B', ')', '(', 'a', 'A', '(', 'a', 'A', 'B', ')', ')', ')']
Tree(root=(fun, index=0), parents=None, size=13)
Tree(root=(fun, index=0), parents=None, size=13)


<function print>

In [28]:
for idx, (thm, _) in enumerate(datapoints):
    if idx % 1000 == 0:
        print(f'{idx} / {len(datapoints)}')
    thm = process_theorem(thm)
    thm_tree = thm_to_tree(thm)
    merge_subexpressions(thm_tree)

0 / 13677
1000 / 13677


KeyboardInterrupt: 

In [79]:
print(len(vocab_ls), len(vocab_goal_ls), len(vocab_thms_ls))

1254 1107 1191


In [117]:
print(distinct_features - vocab_ls)

{'/\\\\', '\\\\/'}


In [1]:
counter = dict()
for _, y in datapoints:
    if y in counter:
        counter[y] += 1
    else:
        counter[y] = 1
counter = list(counter.items())
counter.sort(key=lambda x: x[1], reverse=True)
percentages = [(x, y/len(datapoints)*100) for x,y in counter]
percentages

NameError: name 'datapoints' is not defined

In [13]:
class TopLevelProofDataset(InMemoryDataset):
    def __init__(self, root='', transform=None, pre_transform=None):
        super(TopLevelProofDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])
        
    @property
    def raw_file_names(self):
        return []
    
    @property
    def processed_file_names(self):
        return ['../top_level_proofs.dataset']
    
    def download(self):
        pass
    
    def process(self):
        data_list = []
        
        for thm, y in datapoints:
            thm = process_theorem(thm)
            tree = thm_to_tree(thm)
            
            x, edge_index = graph_to_data(tree, list(distinct_features))
            
            data = Data(x=x, edge_index=edge_index, y=y)
            data_list.append(data)
        
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [14]:
dataset = TopLevelProofDataset()

In [15]:
loader = DataLoader(dataset, batch_size=1, shuffle=True)

In [16]:
for batch in loader:
    print(batch)
    

Batch(batch=[195], edge_index=[2, 194], x=[195, 1], y=[1])
Batch(batch=[280], edge_index=[2, 279], x=[280, 1], y=[1])
Batch(batch=[2379], edge_index=[2, 2378], x=[2379, 1], y=[1])
Batch(batch=[277], edge_index=[2, 276], x=[277, 1], y=[1])
Batch(batch=[291], edge_index=[2, 290], x=[291, 1], y=[1])
Batch(batch=[190], edge_index=[2, 189], x=[190, 1], y=[1])
Batch(batch=[65], edge_index=[2, 64], x=[65, 1], y=[1])
Batch(batch=[203], edge_index=[2, 202], x=[203, 1], y=[1])
Batch(batch=[180], edge_index=[2, 179], x=[180, 1], y=[1])
Batch(batch=[163], edge_index=[2, 162], x=[163, 1], y=[1])
Batch(batch=[207], edge_index=[2, 206], x=[207, 1], y=[1])
Batch(batch=[387], edge_index=[2, 386], x=[387, 1], y=[1])
Batch(batch=[222], edge_index=[2, 221], x=[222, 1], y=[1])
Batch(batch=[303], edge_index=[2, 302], x=[303, 1], y=[1])
Batch(batch=[82], edge_index=[2, 81], x=[82, 1], y=[1])
Batch(batch=[588], edge_index=[2, 587], x=[588, 1], y=[1])
Batch(batch=[732], edge_index=[2, 731], x=[732, 1], y=[1])


Batch(batch=[1670], edge_index=[2, 1669], x=[1670, 1], y=[1])
Batch(batch=[327], edge_index=[2, 326], x=[327, 1], y=[1])
Batch(batch=[949], edge_index=[2, 948], x=[949, 1], y=[1])
Batch(batch=[379], edge_index=[2, 378], x=[379, 1], y=[1])
Batch(batch=[250], edge_index=[2, 249], x=[250, 1], y=[1])
Batch(batch=[162], edge_index=[2, 161], x=[162, 1], y=[1])
Batch(batch=[437], edge_index=[2, 436], x=[437, 1], y=[1])
Batch(batch=[223], edge_index=[2, 222], x=[223, 1], y=[1])
Batch(batch=[626], edge_index=[2, 625], x=[626, 1], y=[1])
Batch(batch=[107], edge_index=[2, 106], x=[107, 1], y=[1])
Batch(batch=[82], edge_index=[2, 81], x=[82, 1], y=[1])
Batch(batch=[124], edge_index=[2, 123], x=[124, 1], y=[1])
Batch(batch=[326], edge_index=[2, 325], x=[326, 1], y=[1])
Batch(batch=[137], edge_index=[2, 136], x=[137, 1], y=[1])
Batch(batch=[997], edge_index=[2, 996], x=[997, 1], y=[1])
Batch(batch=[175], edge_index=[2, 174], x=[175, 1], y=[1])
Batch(batch=[208], edge_index=[2, 207], x=[208, 1], y=[1

Batch(batch=[344], edge_index=[2, 343], x=[344, 1], y=[1])
Batch(batch=[934], edge_index=[2, 933], x=[934, 1], y=[1])
Batch(batch=[482], edge_index=[2, 481], x=[482, 1], y=[1])
Batch(batch=[191], edge_index=[2, 190], x=[191, 1], y=[1])
Batch(batch=[821], edge_index=[2, 820], x=[821, 1], y=[1])
Batch(batch=[308], edge_index=[2, 307], x=[308, 1], y=[1])
Batch(batch=[93], edge_index=[2, 92], x=[93, 1], y=[1])
Batch(batch=[288], edge_index=[2, 287], x=[288, 1], y=[1])
Batch(batch=[233], edge_index=[2, 232], x=[233, 1], y=[1])
Batch(batch=[93], edge_index=[2, 92], x=[93, 1], y=[1])
Batch(batch=[987], edge_index=[2, 986], x=[987, 1], y=[1])
Batch(batch=[669], edge_index=[2, 668], x=[669, 1], y=[1])
Batch(batch=[159], edge_index=[2, 158], x=[159, 1], y=[1])
Batch(batch=[163], edge_index=[2, 162], x=[163, 1], y=[1])
Batch(batch=[1283], edge_index=[2, 1282], x=[1283, 1], y=[1])
Batch(batch=[357], edge_index=[2, 356], x=[357, 1], y=[1])
Batch(batch=[73], edge_index=[2, 72], x=[73, 1], y=[1])
Bat

Batch(batch=[124], edge_index=[2, 123], x=[124, 1], y=[1])
Batch(batch=[167], edge_index=[2, 166], x=[167, 1], y=[1])
Batch(batch=[292], edge_index=[2, 291], x=[292, 1], y=[1])
Batch(batch=[100], edge_index=[2, 99], x=[100, 1], y=[1])
Batch(batch=[247], edge_index=[2, 246], x=[247, 1], y=[1])
Batch(batch=[163], edge_index=[2, 162], x=[163, 1], y=[1])
Batch(batch=[137], edge_index=[2, 136], x=[137, 1], y=[1])
Batch(batch=[280], edge_index=[2, 279], x=[280, 1], y=[1])
Batch(batch=[369], edge_index=[2, 368], x=[369, 1], y=[1])
Batch(batch=[316], edge_index=[2, 315], x=[316, 1], y=[1])
Batch(batch=[631], edge_index=[2, 630], x=[631, 1], y=[1])
Batch(batch=[137], edge_index=[2, 136], x=[137, 1], y=[1])
Batch(batch=[124], edge_index=[2, 123], x=[124, 1], y=[1])
Batch(batch=[339], edge_index=[2, 338], x=[339, 1], y=[1])
Batch(batch=[595], edge_index=[2, 594], x=[595, 1], y=[1])
Batch(batch=[45], edge_index=[2, 44], x=[45, 1], y=[1])
Batch(batch=[221], edge_index=[2, 220], x=[221, 1], y=[1])
B

Batch(batch=[522], edge_index=[2, 521], x=[522, 1], y=[1])
Batch(batch=[139], edge_index=[2, 138], x=[139, 1], y=[1])
Batch(batch=[330], edge_index=[2, 329], x=[330, 1], y=[1])
Batch(batch=[1092], edge_index=[2, 1091], x=[1092, 1], y=[1])
Batch(batch=[484], edge_index=[2, 483], x=[484, 1], y=[1])
Batch(batch=[727], edge_index=[2, 726], x=[727, 1], y=[1])
Batch(batch=[259], edge_index=[2, 258], x=[259, 1], y=[1])
Batch(batch=[441], edge_index=[2, 440], x=[441, 1], y=[1])
Batch(batch=[709], edge_index=[2, 708], x=[709, 1], y=[1])
Batch(batch=[112], edge_index=[2, 111], x=[112, 1], y=[1])
Batch(batch=[181], edge_index=[2, 180], x=[181, 1], y=[1])
Batch(batch=[93], edge_index=[2, 92], x=[93, 1], y=[1])
Batch(batch=[148], edge_index=[2, 147], x=[148, 1], y=[1])
Batch(batch=[288], edge_index=[2, 287], x=[288, 1], y=[1])
Batch(batch=[167], edge_index=[2, 166], x=[167, 1], y=[1])
Batch(batch=[304], edge_index=[2, 303], x=[304, 1], y=[1])
Batch(batch=[628], edge_index=[2, 627], x=[628, 1], y=[1

Batch(batch=[214], edge_index=[2, 213], x=[214, 1], y=[1])
Batch(batch=[109], edge_index=[2, 108], x=[109, 1], y=[1])
Batch(batch=[91], edge_index=[2, 90], x=[91, 1], y=[1])
Batch(batch=[993], edge_index=[2, 992], x=[993, 1], y=[1])
Batch(batch=[219], edge_index=[2, 218], x=[219, 1], y=[1])
Batch(batch=[225], edge_index=[2, 224], x=[225, 1], y=[1])
Batch(batch=[373], edge_index=[2, 372], x=[373, 1], y=[1])
Batch(batch=[506], edge_index=[2, 505], x=[506, 1], y=[1])
Batch(batch=[1833], edge_index=[2, 1832], x=[1833, 1], y=[1])
Batch(batch=[93], edge_index=[2, 92], x=[93, 1], y=[1])
Batch(batch=[236], edge_index=[2, 235], x=[236, 1], y=[1])
Batch(batch=[484], edge_index=[2, 483], x=[484, 1], y=[1])
Batch(batch=[98], edge_index=[2, 97], x=[98, 1], y=[1])
Batch(batch=[797], edge_index=[2, 796], x=[797, 1], y=[1])
Batch(batch=[201], edge_index=[2, 200], x=[201, 1], y=[1])
Batch(batch=[975], edge_index=[2, 974], x=[975, 1], y=[1])
Batch(batch=[590], edge_index=[2, 589], x=[590, 1], y=[1])
Bat

Batch(batch=[240], edge_index=[2, 239], x=[240, 1], y=[1])
Batch(batch=[124], edge_index=[2, 123], x=[124, 1], y=[1])
Batch(batch=[621], edge_index=[2, 620], x=[621, 1], y=[1])
Batch(batch=[95], edge_index=[2, 94], x=[95, 1], y=[1])
Batch(batch=[1039], edge_index=[2, 1038], x=[1039, 1], y=[1])
Batch(batch=[971], edge_index=[2, 970], x=[971, 1], y=[1])
Batch(batch=[137], edge_index=[2, 136], x=[137, 1], y=[1])
Batch(batch=[345], edge_index=[2, 344], x=[345, 1], y=[1])
Batch(batch=[170], edge_index=[2, 169], x=[170, 1], y=[1])
Batch(batch=[325], edge_index=[2, 324], x=[325, 1], y=[1])
Batch(batch=[137], edge_index=[2, 136], x=[137, 1], y=[1])
Batch(batch=[680], edge_index=[2, 679], x=[680, 1], y=[1])
Batch(batch=[137], edge_index=[2, 136], x=[137, 1], y=[1])
Batch(batch=[607], edge_index=[2, 606], x=[607, 1], y=[1])
Batch(batch=[117], edge_index=[2, 116], x=[117, 1], y=[1])
Batch(batch=[217], edge_index=[2, 216], x=[217, 1], y=[1])
Batch(batch=[246], edge_index=[2, 245], x=[246, 1], y=[1

Batch(batch=[692], edge_index=[2, 691], x=[692, 1], y=[1])
Batch(batch=[163], edge_index=[2, 162], x=[163, 1], y=[1])
Batch(batch=[102], edge_index=[2, 101], x=[102, 1], y=[1])
Batch(batch=[287], edge_index=[2, 286], x=[287, 1], y=[1])
Batch(batch=[249], edge_index=[2, 248], x=[249, 1], y=[1])
Batch(batch=[391], edge_index=[2, 390], x=[391, 1], y=[1])
Batch(batch=[275], edge_index=[2, 274], x=[275, 1], y=[1])
Batch(batch=[426], edge_index=[2, 425], x=[426, 1], y=[1])
Batch(batch=[181], edge_index=[2, 180], x=[181, 1], y=[1])
Batch(batch=[288], edge_index=[2, 287], x=[288, 1], y=[1])
Batch(batch=[1325], edge_index=[2, 1324], x=[1325, 1], y=[1])
Batch(batch=[287], edge_index=[2, 286], x=[287, 1], y=[1])
Batch(batch=[93], edge_index=[2, 92], x=[93, 1], y=[1])
Batch(batch=[188], edge_index=[2, 187], x=[188, 1], y=[1])
Batch(batch=[232], edge_index=[2, 231], x=[232, 1], y=[1])
Batch(batch=[287], edge_index=[2, 286], x=[287, 1], y=[1])
Batch(batch=[532], edge_index=[2, 531], x=[532, 1], y=[1

Batch(batch=[361], edge_index=[2, 360], x=[361, 1], y=[1])
Batch(batch=[681], edge_index=[2, 680], x=[681, 1], y=[1])
Batch(batch=[177], edge_index=[2, 176], x=[177, 1], y=[1])
Batch(batch=[1105], edge_index=[2, 1104], x=[1105, 1], y=[1])
Batch(batch=[294], edge_index=[2, 293], x=[294, 1], y=[1])
Batch(batch=[437], edge_index=[2, 436], x=[437, 1], y=[1])
Batch(batch=[93], edge_index=[2, 92], x=[93, 1], y=[1])
Batch(batch=[1135], edge_index=[2, 1134], x=[1135, 1], y=[1])
Batch(batch=[555], edge_index=[2, 554], x=[555, 1], y=[1])
Batch(batch=[221], edge_index=[2, 220], x=[221, 1], y=[1])
Batch(batch=[1361], edge_index=[2, 1360], x=[1361, 1], y=[1])
Batch(batch=[485], edge_index=[2, 484], x=[485, 1], y=[1])
Batch(batch=[735], edge_index=[2, 734], x=[735, 1], y=[1])
Batch(batch=[117], edge_index=[2, 116], x=[117, 1], y=[1])
Batch(batch=[141], edge_index=[2, 140], x=[141, 1], y=[1])
Batch(batch=[126], edge_index=[2, 125], x=[126, 1], y=[1])
Batch(batch=[308], edge_index=[2, 307], x=[308, 1]

Batch(batch=[138], edge_index=[2, 137], x=[138, 1], y=[1])
Batch(batch=[202], edge_index=[2, 201], x=[202, 1], y=[1])
Batch(batch=[171], edge_index=[2, 170], x=[171, 1], y=[1])
Batch(batch=[497], edge_index=[2, 496], x=[497, 1], y=[1])
Batch(batch=[137], edge_index=[2, 136], x=[137, 1], y=[1])
Batch(batch=[724], edge_index=[2, 723], x=[724, 1], y=[1])
Batch(batch=[449], edge_index=[2, 448], x=[449, 1], y=[1])
Batch(batch=[635], edge_index=[2, 634], x=[635, 1], y=[1])
Batch(batch=[232], edge_index=[2, 231], x=[232, 1], y=[1])
Batch(batch=[93], edge_index=[2, 92], x=[93, 1], y=[1])
Batch(batch=[292], edge_index=[2, 291], x=[292, 1], y=[1])
Batch(batch=[491], edge_index=[2, 490], x=[491, 1], y=[1])
Batch(batch=[636], edge_index=[2, 635], x=[636, 1], y=[1])
Batch(batch=[466], edge_index=[2, 465], x=[466, 1], y=[1])
Batch(batch=[252], edge_index=[2, 251], x=[252, 1], y=[1])
Batch(batch=[311], edge_index=[2, 310], x=[311, 1], y=[1])
Batch(batch=[163], edge_index=[2, 162], x=[163, 1], y=[1])


Batch(batch=[421], edge_index=[2, 420], x=[421, 1], y=[1])
Batch(batch=[578], edge_index=[2, 577], x=[578, 1], y=[1])
Batch(batch=[113], edge_index=[2, 112], x=[113, 1], y=[1])
Batch(batch=[717], edge_index=[2, 716], x=[717, 1], y=[1])
Batch(batch=[485], edge_index=[2, 484], x=[485, 1], y=[1])
Batch(batch=[181], edge_index=[2, 180], x=[181, 1], y=[1])
Batch(batch=[405], edge_index=[2, 404], x=[405, 1], y=[1])
Batch(batch=[379], edge_index=[2, 378], x=[379, 1], y=[1])
Batch(batch=[247], edge_index=[2, 246], x=[247, 1], y=[1])
Batch(batch=[666], edge_index=[2, 665], x=[666, 1], y=[1])
Batch(batch=[124], edge_index=[2, 123], x=[124, 1], y=[1])
Batch(batch=[280], edge_index=[2, 279], x=[280, 1], y=[1])
Batch(batch=[469], edge_index=[2, 468], x=[469, 1], y=[1])
Batch(batch=[433], edge_index=[2, 432], x=[433, 1], y=[1])
Batch(batch=[585], edge_index=[2, 584], x=[585, 1], y=[1])
Batch(batch=[188], edge_index=[2, 187], x=[188, 1], y=[1])
Batch(batch=[449], edge_index=[2, 448], x=[449, 1], y=[1

KeyboardInterrupt: 

In [6]:
tree = Tree(root='a', parent=None)
print(tree)

tree
tree.add_subtrees(Tree(root='c', parent=tree),
                 Tree(root='l', parent=tree))
print(tree)
for subtree in tree.subtrees:
    print(subtree)


Tree(root=a, parent=None, size=1)
Tree(root=a, parent=None, size=3)
Tree(root=c, parent=Tree(root=a, parent=None, size=3), size=1)
Tree(root=l, parent=Tree(root=a, parent=None, size=3), size=1)


In [83]:
distinct_features

{'!',
 '$',
 '$$',
 '%',
 '%%',
 '*',
 '*_c',
 '+',
 '++',
 '+_c',
 ',',
 '-',
 '--->',
 '-->',
 '..',
 '/\\\\',
 '1',
 '2',
 '3',
 '4',
 '<',
 '<<',
 '<<<',
 '<=',
 '<=_c',
 '<_c',
 '=',
 '==',
 '==>',
 '=_c',
 '>',
 '>=',
 '>=_c',
 '>_c',
 '?',
 '?!',
 '?0',
 '?1',
 '?10',
 '?11',
 '?2',
 '?3',
 '?4',
 '?5',
 '?6',
 '?7',
 '?8',
 '?9',
 '@',
 'A',
 "A'",
 'ABC',
 'ALL',
 'ALL2',
 'ANR',
 'APPEND',
 'AR',
 'ASSOC',
 'Arg',
 'B',
 "B'",
 'B0',
 'BIT0',
 'BIT1',
 'BOTTOM',
 'BUTLAST',
 'C',
 'CARD',
 'CASEWISE',
 'COND',
 'CONS',
 'CONSTR',
 'COUNTABLE',
 'CROSS',
 'Cx',
 'D',
 "D'",
 'DECIMAL',
 'DELETE',
 'DIFF',
 'DISJOINT',
 'DIV',
 'E',
 'EL',
 'EMPTY',
 'ENR',
 'EVEN',
 'EX',
 'EXP',
 'EXTENSIONAL',
 'F',
 'FACT',
 'FCONS',
 'FILTER',
 'FINITE',
 "FINITE'",
 'FINREC',
 'FST',
 'Fn',
 'G',
 'GABS',
 'GEN%PVAR%0',
 'GEN%PVAR%1',
 'GEQ',
 'GSPEC',
 'H',
 'HAS_SIZE',
 'HD',
 'I',
 'IMAGE',
 'IN',
 'IND_0',
 'IND_SUC',
 'INFINITE',
 'INJF',
 'INJN',
 'INL',
 'INR',
 'INSERT',
 'INTER',

In [7]:
test_theorem = "(a (c (fun (fun A (bool)) (bool)) !) (l (v A x) (a (a (c (fun A (fun A (bool))) =) (v A x)) (v A x))))"
test_theorem = test_theorem.replace('(', ' ( ')
test_theorem = test_theorem.replace(')', ' ) ')
test_theorem = test_theorem.split()

In [11]:
test_theorem = "(a (c (fun (fun A (bool)) (bool)) !) (l (v A x) (a (a (c (fun A (fun A (bool))) =) (v A x)) (v A x))))"
processed_thm_str = process_theorem(test_theorem)
tree = thm_to_tree(processed_thm_str)
print(tree)

Tree(root=(a, index=0), parent=None, size=27)


In [12]:
for subtree in tree.subtrees:
    print(subtree)
    print(subtree.view_subtree())

Tree(root=(c, index=1), parent=Tree(root=(a, index=0), parent=None, size=27), size=7)
((c, index=1), (fun, index=3), (!, index=4), (fun, index=7), (bool, index=8), (A, index=13), (bool, index=14))
Tree(root=(l, index=2), parent=Tree(root=(a, index=0), parent=None, size=27), size=19)
((l, index=2), (v, index=5), (a, index=6), (A, index=9), (x, index=10), (a, index=11), (v, index=12), (c, index=15), (v, index=16), (A, index=17), (x, index=18), (fun, index=19), (=, index=20), (A, index=21), (x, index=22), (A, index=23), (fun, index=24), (A, index=25), (bool, index=26))


In [13]:
test_theorem_2 = "(a (c (fun (fun (bool) (bool)) (bool)) !) (l (v (bool) t1) (a (c (fun (fun (bool) (bool)) (bool)) !) (l (v (bool) t2) (a (a (c (fun (bool) (fun (bool) (bool))) =) (a (a (c (fun (bool) (fun (bool) (bool))) /\\) (v (bool) t1)) (v (bool) t2))) (a (a (c (fun (bool) (fun (bool) (bool))) /\\) (v (bool) t2)) (v (bool) t1)))))))"
proc_thm_str_2 = process_theorem(test_theorem_2)
tree = thm_to_tree(proc_thm_str_2)
print(tree)

Tree(root=(a, index=0), parent=None, size=63)


In [14]:
stack = []
stack.append(tree)
while stack:
    x = stack.pop()
    print(x.root, len(x), len(x.subtrees))
    for subtree in x.subtrees[::-1]:
        stack.append(subtree)

(a, index=0) 63 2
(c, index=1) 7 2
(fun, index=3) 5 2
(fun, index=7) 3 2
(bool, index=13) 1 0
(bool, index=14) 1 0
(bool, index=8) 1 0
(!, index=4) 1 0
(l, index=2) 55 2
(v, index=5) 3 2
(bool, index=9) 1 0
(t1, index=10) 1 0
(a, index=6) 51 2
(c, index=11) 7 2
(fun, index=15) 5 2
(fun, index=19) 3 2
(bool, index=25) 1 0
(bool, index=26) 1 0
(bool, index=20) 1 0
(!, index=16) 1 0
(l, index=12) 43 2
(v, index=17) 3 2
(bool, index=21) 1 0
(t2, index=22) 1 0
(a, index=18) 39 2
(a, index=23) 23 2
(c, index=27) 7 2
(fun, index=31) 5 2
(bool, index=39) 1 0
(fun, index=40) 3 2
(bool, index=49) 1 0
(bool, index=50) 1 0
(=, index=32) 1 0
(a, index=28) 15 2
(a, index=33) 11 2
(c, index=41) 7 2
(fun, index=51) 5 2
(bool, index=57) 1 0
(fun, index=58) 3 2
(bool, index=61) 1 0
(bool, index=62) 1 0
(/\, index=52) 1 0
(v, index=42) 3 2
(bool, index=53) 1 0
(t1, index=54) 1 0
(v, index=34) 3 2
(bool, index=43) 1 0
(t2, index=44) 1 0
(a, index=24) 15 2
(a, index=29) 11 2
(c, index=35) 7 2
(fun, index=4

In [16]:
theorem = '(a A (fun A B))'
proc_thm_str = process_theorem(theorem)
tree = thm_to_tree(proc_thm_str)

datum = graph_to_data(tree)
datum

Data(edge_index=[2, 4], x=[5, 1])