# In [None]:
from typing import Optional
import ndjson
import pickle

class Node:
    """One AST node, linked to its parent and to an ordered list of children."""

    def __init__(self, b_i: Optional[int], kind: str, code_pos: str, data: str):
        """Store the node payload; the node starts with no parent and no children."""
        self.branching_idx = b_i
        # Tree links are filled in afterwards through set_parent / add_child.
        self.parent = None
        self.children = []
        self.kind = kind
        self.code_pos = code_pos
        self.data = data

    def set_parent(self, parent: 'Node'):
        """Record *parent* as this node's parent (does not touch parent.children)."""
        self.parent = parent

    def add_child(self, child: 'Node'):
        """Append *child* after the existing children (does not touch child.parent)."""
        self.children.append(child)

    def to_dict(self):
        """Return this subtree as nested dictionaries.

        Only ``kind``, ``code_pos``, ``data`` and ``children`` are emitted;
        ``branching_idx`` and ``parent`` are not part of the output.
        """
        nested = [kid.to_dict() for kid in self.children]
        result = {}
        result['kind'] = self.kind
        result['code_pos'] = self.code_pos
        result['data'] = self.data
        result['children'] = nested
        return result

def json_to_tree(data: dict) -> Node:
    """Build a tree of Node objects from a nested JSON dictionary.

    The tree is built with an explicit stack instead of recursion, so a very
    deep AST cannot exhaust the interpreter's recursion limit.

    Args:
        data: Dictionary with optional ``kind``, ``code_pos``, ``data`` and
            ``children`` keys; ``children`` holds dictionaries of the same
            shape. A missing or null ``children`` entry means "no children".

    Returns:
        The root Node. Every descendant has its ``parent`` set and children
        keep the order they have in the input. ``branching_idx`` is None on
        every node.
    """
    def make_node(spec: dict) -> Node:
        return Node(
            b_i=None,
            kind=spec.get('kind'),
            code_pos=spec.get('code_pos'),
            data=spec.get('data'),
        )

    root = make_node(data)
    # Each entry pairs an already-created node with the JSON it came from,
    # whose children still have to be materialised.
    pending = [(root, data)]
    while pending:
        parent, spec = pending.pop()
        for child_spec in spec.get('children') or ():
            child = make_node(child_spec)
            child.set_parent(parent)
            parent.add_child(child)
            pending.append((child, child_spec))

    return root

#NODE TO NODE PATHS
# Function to collect all leaf nodes iteratively using DFS
def collect_leaves_iterative(root):
    """Collect every leaf below *root* with an explicit-stack depth-first walk.

    Args:
        root: Tree root exposing ``kind`` and ``children``, or None.

    Returns:
        A list of ``(leaf, kinds)`` pairs in left-to-right order, where
        ``kinds`` is the list of ``kind`` values from the root down to and
        including the leaf. An empty list when *root* is None.
    """
    found = []
    if root is None:
        return found

    # Each entry is (node, kinds of its ancestors from the root down).
    pending = [(root, [])]
    while pending:
        current, prefix = pending.pop()
        kinds = [*prefix, current.kind]

        # A node without children is a leaf.
        if not current.children:
            found.append((current, kinds))

        # Reverse so that the first child ends up on top of the stack.
        pending.extend((child, kinds) for child in reversed(current.children))

    return found


# Function to find the Lowest Common Ancestor (LCA) iteratively
def find_lca_iterative(n1_path, n2_path):
    """Return the last element of the longest common prefix of two paths.

    Elements are compared position by position with ``==``; scanning stops
    at the first mismatch or at the end of the shorter path.

    Returns:
        The deepest shared element, or None when the paths differ at the
        first position or either path is empty.
    """
    deepest_shared = None
    for left, right in zip(n1_path, n2_path):
        if left == right:
            deepest_shared = left
        else:
            break
    return deepest_shared


def find_leaf_to_leaf_paths_iterative(root):
    """Enumerate the path between every unordered pair of leaves below *root*.

    Each path climbs from the first leaf to the lowest common ancestor (LCA)
    and descends to the second leaf. It is encoded as a tuple
    ``(leaf1.data, kind, ..., kind, leaf2.data)`` in which the kinds cover
    both leaves and every node in between, the LCA appearing once.

    Ancestors are matched by node identity, not by ``kind``: sibling subtrees
    routinely repeat the same kinds, so comparing kind strings selects the
    wrong ancestor.

    NOTE(review): any other code that recomputes paths to look them up in a
    vocabulary built from this function must use the same extraction.

    Args:
        root: Tree root exposing ``kind``, ``data`` and ``children``, or None.

    Returns:
        A pair ``(leaf_values, paths)``. ``leaf_values`` lists ``data`` of
        every leaf in left-to-right order; ``paths`` holds one tuple per leaf
        pair ``(i, j)`` with ``i < j`` in that same order. Both are empty
        when *root* is None.
    """
    if root is None:
        return [], []

    # Depth-first walk that records, for every leaf, the chain of nodes from
    # the root down to that leaf.
    leaf_chains = []
    stack = [(root, (root,))]
    while stack:
        node, chain = stack.pop()
        if not node.children:
            leaf_chains.append(chain)
            continue
        # Push in reverse so the leftmost child is visited first.
        for child in reversed(node.children):
            stack.append((child, chain + (child,)))

    leaf_to_leaf_paths = []
    leaf_count = len(leaf_chains)
    for i, chain1 in enumerate(leaf_chains):
        for j in range(i + 1, leaf_count):
            chain2 = leaf_chains[j]

            # Both chains start at root, so index 0 is always shared; advance
            # while the two chains still pass through the very same node.
            lca_index = 0
            limit = min(len(chain1), len(chain2))
            while (lca_index + 1 < limit
                   and chain1[lca_index + 1] is chain2[lca_index + 1]):
                lca_index += 1

            # leaf1 -> ... -> LCA, then LCA's child -> ... -> leaf2.
            climb = [n.kind for n in reversed(chain1[lca_index:])]
            descend = [n.kind for n in chain2[lca_index + 1:]]

            leaf_to_leaf_paths.append(
                (chain1[-1].data, *climb, *descend, chain2[-1].data)
            )

    return [chain[-1].data for chain in leaf_chains], leaf_to_leaf_paths

def find_tag(root) -> Optional[str]:
    """Return the name declared by a function-definition node.

    Scans the direct children of *root* for ``FunctionDeclarator`` nodes and,
    inside each, for the first direct child of kind ``IdentifierDeclarator``.

    Args:
        root: Node expected to be a FunctionDefinition, exposing ``children``
            whose items expose ``kind``, ``children`` and ``data``.

    Returns:
        ``str(data)`` of the first matching identifier node, or None when the
        tree contains no such node.
    """
    for declarator in root.children:
        if declarator.kind != "FunctionDeclarator":
            continue
        for candidate in declarator.children:
            if candidate.kind == "IdentifierDeclarator":
                return str(candidate.data)
    # No declarator carried an identifier: make the "not found" result explicit.
    return None


def generate_vocabs(file_paths):
    """Build value, path and tag vocabularies from NDJSON files of function ASTs.

    Every record of every file is converted to a tree, its leaf values and
    leaf-to-leaf paths are extracted, and the distinct items are numbered.

    Args:
        file_paths: Iterable of paths to NDJSON files, one function AST per
            record.

    Returns:
        A dictionary with four entries:

        * ``'value_vocab'``: leaf value -> index.
        * ``'path_vocab'``: path tuple (a leaf-to-leaf path with its first
          and last elements, the two leaf values, removed) -> index.
        * ``'tags_vocab'``: function tag -> index.
        * ``'max_num_contexts'``: the largest number of leaf-to-leaf paths
          found in a single function.

        Indices start at 1 and follow sorted order; index 0 is reserved for
        the padding entries ``'<PAD>'`` / ``('<PAD>',)``.
    """
    value_vocab = set()  # all distinct leaf values
    path_vocab = set()   # all distinct paths
    tags_vocab = set()   # all distinct function tags
    max_num_contexts = 0

    for file_path in file_paths:
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(file_path, 'r', encoding='utf-8') as ndjson_file:
            functions = ndjson.load(ndjson_file)

        for function_json in functions:
            func_root = json_to_tree(function_json)
            tag = find_tag(func_root)
            func_values, func_paths = find_leaf_to_leaf_paths_iterative(func_root)
            max_num_contexts = max(len(func_paths), max_num_contexts)

            value_vocab.update(func_values)

            # Drop the leaf values at both ends; only the inner part is the path.
            path_vocab.update(context[1:-1] for context in func_paths)

            # find_tag yields None when no identifier is found; a None in the
            # set would make the sorted() call below fail against str tags.
            if tag is not None:
                tags_vocab.add(tag)

    # Number the entries from 1 in sorted order; 0 is kept for padding.
    value_vocab_dict = {value: idx + 1 for idx, value in enumerate(sorted(value_vocab))}
    path_vocab_dict = {path: idx + 1 for idx, path in enumerate(sorted(path_vocab))}
    tags_vocab_dict = {tag: idx + 1 for idx, tag in enumerate(sorted(tags_vocab))}

    value_vocab_dict['<PAD>'] = 0
    path_vocab_dict[('<PAD>',)] = 0
    tags_vocab_dict['<PAD>'] = 0

    return {
        'value_vocab': value_vocab_dict,
        'path_vocab': path_vocab_dict,
        'tags_vocab': tags_vocab_dict,
        'max_num_contexts': max_num_contexts,
    }

# NDJSON inputs scanned to build the vocabularies.
train = 'strat_train_functionsASTs.ndjson'
valid = 'strat_validate_functionsASTs.ndjson'
# Destination file; the content is written with pickle, not JSON.
vocabs_json = 'vocabs.pkl'

print("Started generating vocabs...")
vocabs_dict = generate_vocabs(file_paths=[train, valid])

# pickle needs a binary-mode file object.
with open(vocabs_json, mode='wb') as f:
    pickle.dump(vocabs_dict, f)

print("Done.")