# Parser 

In [61]:
#| default_exp parser

In [62]:
#| export
import CodeSyntaxConcept

from CodeSyntaxConcept.tokenizer import CodeTokenizer, get_token_type
import CodeSyntaxConcept.utils as utils
import pandas as pd

In [63]:
#| hide
from nbdev.showdoc import *

In [80]:
#| export
class TreeSitterParser:
    """Label source-code tokens with tree-sitter AST node types.

    Wraps a `CodeTokenizer`, using its `parser` attribute to build syntax
    trees, its `node_types` list to turn node-type ids into names, and the
    tokenizer itself (called directly) to encode source for the model.
    """

    def __init__(self, tokenizer: CodeTokenizer):
        self.tokenizer = tokenizer

    def _node_type_name(self, type_id: int) -> str:
        """Return the node-type name for `type_id`, or 'ERROR' for a negative id.

        The tokenizer uses -1 as a sentinel id (presumably tokens with no
        matching AST node, e.g. whitespace -- confirm in CodeTokenizer). The
        explicit check matters: indexing `node_types` with a negative id would
        silently wrap around to the end of the list instead of failing.
        """
        if type_id < 0:
            return 'ERROR'
        return self.tokenizer.node_types[type_id]

    def process_source_code(self, source_code: str) -> pd.DataFrame:
        """Parse `source_code` with tree-sitter and tabulate the collected nodes.

        Args:
            source_code: Source text to parse; it is UTF-8 encoded before parsing.

        Returns:
            A DataFrame with one row per node collected by `utils.traverse`
            and columns `input` (the node's source text), `ast_concept` (the
            node's type) and `parent_ast_concept` (the parent's type, or None
            when the node has no parent, e.g. a lone root node).
        """
        tree = self.tokenizer.parser.parse(source_code.encode("utf8"))
        ast_nodes = []
        # utils.traverse fills ast_nodes in place (presumably with leaf nodes
        # only -- verify in utils.traverse).
        utils.traverse(tree.root_node, ast_nodes)
        rows = [
            (
                node.text.decode("utf-8"),
                node.type,
                # The root node has no parent; avoid AttributeError on None.
                node.parent.type if node.parent is not None else None,
            )
            for node in ast_nodes
        ]
        return pd.DataFrame(rows, columns=['input', 'ast_concept', 'parent_ast_concept'])

    def process_model_source_code(self, source_code: str):
        """Tokenize `source_code` with the model tokenizer and label every token.

        Args:
            source_code: Source text to encode with `self.tokenizer`.

        Returns:
            A tuple `(encoding, dataframe)`. `encoding` is the tokenizer output
            unchanged. `dataframe` has one row per entry of
            `encoding['input_ids']`, with columns `input_id`, `ast_concept`
            (name for `encoding['ast_ids']`) and `parent_ast_concept` (name for
            `encoding['parent_ast_ids']`). Each column checks its own id, so a
            negative id yields 'ERROR' independently in either column.
        """
        encoding = self.tokenizer(source_code)
        ast_ids = encoding['ast_ids']
        parent_ast_ids = encoding['parent_ast_ids']
        rows = [
            (
                input_id,
                self._node_type_name(ast_ids[index]),
                # Guard on the parent id itself (not the node id) so a -1
                # parent can never resolve to node_types[-1].
                self._node_type_name(parent_ast_ids[index]),
            )
            for index, input_id in enumerate(encoding['input_ids'])
        ]
        return encoding, pd.DataFrame(rows, columns=['input_id', 'ast_concept', 'parent_ast_concept'])

# Testing

In [81]:
# Pretrained checkpoint whose tokenizer is loaded below; the commented line
# is an alternative (larger, per its name) checkpoint.
checkpoint = "EleutherAI/gpt-neo-125M"
# checkpoint = "EleutherAI/gpt-neo-1.3B"

# Example snippet parsed in the following cells.
source_code = "def multiply_numbers(a,b):\n    return a*b"

tokenizer = CodeTokenizer.from_pretrained(checkpoint, "python")
parser = TreeSitterParser(tokenizer=tokenizer)

In [82]:
# Node-level view: one row per tree-sitter node collected from the snippet.
print(parser.process_source_code(source_code=source_code))

               input ast_concept   parent_ast_concept
0                def         def  function_definition
1   multiply_numbers  identifier  function_definition
2                  (           (           parameters
3                  a  identifier           parameters
4                  ,           ,           parameters
5                  b  identifier           parameters
6                  )           )           parameters
7                  :           :  function_definition
8             return      return     return_statement
9                  a  identifier      binary_operator
10                 *           *      binary_operator
11                 b  identifier      binary_operator


In [85]:
# Token-level view: encode the snippet with the model tokenizer and align
# each token id with its AST node type and parent node type.
(source_code_encoding,
 source_code_dataframe) = parser.process_model_source_code(source_code=source_code)
print(source_code_encoding['input_ids'])
print(source_code_dataframe)

[4299, 29162, 62, 77, 17024, 7, 64, 11, 65, 2599, 198, 220, 220, 220, 1441, 257, 9, 65]
    input_id ast_concept   parent_ast_concept
0       4299         def  function_definition
1      29162  identifier  function_definition
2         62  identifier  function_definition
3         77  identifier  function_definition
4      17024  identifier  function_definition
5          7           (           parameters
6         64  identifier           parameters
7         11           ,           parameters
8         65  identifier           parameters
9       2599           )           parameters
10       198       ERROR                ERROR
11       220       ERROR                ERROR
12       220       ERROR                ERROR
13       220       ERROR                ERROR
14      1441      return     return_statement
15       257  identifier      binary_operator
16         9           *      binary_operator
17        65  identifier      binary_operator
