# Parser 

In [None]:
#| default_exp parser

In [None]:
#| export
import CodeSyntaxConcept

from CodeSyntaxConcept.tokenizer import CodeTokenizer, get_token_type
import CodeSyntaxConcept.utils as utils
import pandas as pd

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
class TreeSitterParser:
    """Class to parse source code snippets using the Tree-sitter grammar"""
    def __init__(self, tokenizer: CodeTokenizer):
        # The tokenizer is used three ways below: `.parser` (tree-sitter parser),
        # `.node_types` (node type vocabulary) and by calling it to encode source code.
        self.tokenizer = tokenizer

    def process_source_code(self, 
                            source_code: str #Source code to parse
                            ) -> list: # One (node text, node type, parent node type) tuple per collected node
        """Process the given source code snippet and gets the ast node types"""
        ast_representation = self.tokenizer.parser.parse(bytes(source_code, "utf8"))
        ast_nodes = []
        # utils.traverse is expected to append the nodes of interest to `ast_nodes` in place.
        utils.traverse(ast_representation.root_node, ast_nodes)
        # A node without a parent (e.g. the root) yields None as its parent type
        # instead of raising AttributeError on `None.type`.
        return [
            (
                node.text.decode("utf-8"),
                node.type,
                node.parent.type if node.parent is not None else None,
            )
            for node in ast_nodes
        ]

    def process_model_source_code(self, 
                                  source_code: str # Source code to process
                                  ) -> tuple: # (encoding, list of (input id, node type, parent node type) tuples)
        """Process the given source code snippet and gets the encodings and ast types"""
        source_code_encoding = self.tokenizer(source_code)
        ast_ids = source_code_encoding['ast_ids']
        parent_ast_ids = source_code_encoding['parent_ast_ids']
        node_types = self.tokenizer.node_types
        source_code_ast_types = []
        for position, input_id in enumerate(source_code_encoding['input_ids']):
            ast_id = ast_ids[position]
            if ast_id == -1:
                # An id of -1 marks a token with no AST node: both columns are 'ERROR'.
                source_code_ast_types.append((input_id, 'ERROR', 'ERROR'))
                continue
            parent_ast_id = parent_ast_ids[position]
            # Check the parent id itself: indexing node_types with -1 would silently
            # return the last entry of the vocabulary instead of flagging an error.
            parent_type = 'ERROR' if parent_ast_id == -1 else node_types[parent_ast_id]
            source_code_ast_types.append((input_id, node_types[ast_id], parent_type))
        return source_code_encoding, source_code_ast_types

# Testing

In [None]:
#| hide
#| eval: false
# Checkpoint identifier passed to CodeTokenizer.from_pretrained below.
checkpoint = "EleutherAI/gpt-neo-125M"
# Alternative checkpoint, kept commented out for quick switching:
#checkpoint = "EleutherAI/gpt-neo-1.3B"

# Build the tokenizer (second argument "python" presumably selects the grammar
# language -- confirm in CodeTokenizer) and wrap it in the parser under test.
tokenizer = CodeTokenizer.from_pretrained(checkpoint, "python")
parser = TreeSitterParser(tokenizer)

# Small two-line sample function used as input by the cells below.
source_code = "def multiply_numbers(a,b):\n    return a*b"

In [None]:
#| hide
#| eval: false
# Show the (node text, node type, parent node type) tuples produced for the sample snippet.
print(parser.process_source_code(source_code))

[('def', 'def', 'function_definition'), ('multiply_numbers', 'identifier', 'function_definition'), ('(', '(', 'parameters'), ('a', 'identifier', 'parameters'), (',', ',', 'parameters'), ('b', 'identifier', 'parameters'), (')', ')', 'parameters'), (':', ':', 'function_definition'), ('return', 'return', 'return_statement'), ('a', 'identifier', 'binary_operator'), ('*', '*', 'binary_operator'), ('b', 'identifier', 'binary_operator')]


In [None]:
#| hide
#| eval: false
# process_model_source_code returns the tokenizer encoding plus one
# (input id, node type, parent node type) tuple per input id.
# NOTE: despite its name, `source_code_dataframe` is a plain list of tuples, not a pandas DataFrame.
source_code_encoding, source_code_dataframe = parser.process_model_source_code(source_code)
print(source_code_encoding['input_ids'])
print(source_code_dataframe)

[4299, 29162, 62, 77, 17024, 7, 64, 11, 65, 2599, 198, 220, 220, 220, 1441, 257, 9, 65]
[(4299, 'def', 'function_definition'), (29162, 'identifier', 'function_definition'), (62, 'identifier', 'function_definition'), (77, 'identifier', 'function_definition'), (17024, 'identifier', 'function_definition'), (7, '(', 'parameters'), (64, 'identifier', 'parameters'), (11, ',', 'parameters'), (65, 'identifier', 'parameters'), (2599, ')', 'parameters'), (198, 'ERROR', 'ERROR'), (220, 'ERROR', 'ERROR'), (220, 'ERROR', 'ERROR'), (220, 'ERROR', 'ERROR'), (1441, 'return', 'return_statement'), (257, 'identifier', 'binary_operator'), (9, '*', 'binary_operator'), (65, 'identifier', 'binary_operator')]
