# CODE MODULE

In [4]:
#| default_exp tokenizer

In [10]:
#| export
import CodeCheckList
import CodeCheckList.utils as utils

import json
from transformers import AutoTokenizer
from tree_sitter import Language, Parser

In [6]:
#| hide
from nbdev.showdoc import *

In [11]:
#| export
def get_token_type(
    tok_span: tuple, # (start, end) position of a token in tokenizer
    nodes: list,     # list of tree-sitter nodes
    lines: list,     # list of lines in the code
) -> tuple: # (parent_type, token_type) of the token, or None if no node overlaps it
    """Get the parent AST type and token AST type of a token.

    Scans `nodes` in order and returns the types of the first node whose
    character span overlaps the token span. A node overlaps when the token
    starts inside it, ends inside it, or fully encloses it. Returns None
    when no node overlaps the token.
    """
    tok_start, tok_end = tok_span
    for node in nodes:
        # Spans are converted lazily so the scan stops at the first match
        # instead of converting every node up front.
        node_start = utils.convert_to_offset(node.start_point, lines)
        node_end = utils.convert_to_offset(node.end_point, lines)
        starts_inside = node_start <= tok_start < node_end
        ends_inside = node_start < tok_end <= node_end
        # A token reaching past both ends of a non-empty node (e.g. one that
        # absorbed surrounding whitespace) also overlaps it; the two checks
        # above miss that case.
        encloses = tok_start <= node_start < node_end <= tok_end
        if starts_inside or ends_inside or encloses:
            return node.parent.type, node.type
    return None

In [12]:
#| export
class CodeTokenizer():
    """A tokenizer for code, which aligns the tokens with the AST nodes.

    Calling an instance does three things:

    - tokenizes `code` with `tokenizer`;
    - parses it with `parser`;
    - adds two per-token lists to the returned encoding.

    The lists are `"ast_ids"` (index in `node_types` of the AST node type the
    token falls in) and `"parent_ast_ids"` (index of that node's parent
    type). Tokens that cannot be aligned get `-1` in both lists.
    """
    def __init__(
        self,
        tokenizer,   # tokenizer called with `return_offsets_mapping=True`
        parser,      # tree-sitter `Parser` with the code's language set
        node_types,  # list of AST node type names; an AST id is an index into it
    ):
        self.tokenizer = tokenizer
        self.parser = parser
        self.node_types = node_types

    def __call__(
        self,
        code: str,  # source code to tokenize and parse
    ):              # the tokenizer's encoding plus "ast_ids" and "parent_ast_ids"
        """Tokenize `code` and align every token with its AST node type.

        Special tokens, tokens without character offsets, and tokens that
        overlap no AST node get `-1` in both `"ast_ids"` and
        `"parent_ast_ids"`. Raises `ValueError` when a token's node type or
        parent type is missing from `node_types`.
        """
        encoding = self.tokenizer(code, return_offsets_mapping=True)
        tree = self.parser.parse(bytes(code, "utf8"))
        nodes = []
        utils.traverse(tree.root_node, nodes)

        # Loop invariants, computed once per call rather than once per token.
        # Re-splitting the code for every token made alignment quadratic in
        # the code's length.
        lines = code.split("\n")
        special_ids = set(self.tokenizer.all_special_ids)

        ast_ids = []
        parent_ast_ids = []
        for i, (start, end) in enumerate(encoding.offset_mapping):
            token_id = encoding["input_ids"][i]
            type_info = None
            if token_id not in special_ids and start is not None and end is not None:
                type_info = get_token_type((start, end), nodes, lines)
            if type_info is None:
                # Special token, missing offsets, or no overlapping AST node.
                ast_ids.append(-1)
                parent_ast_ids.append(-1)
                continue
            parent_node_type, node_type = type_info
            try:
                ast_id = self.node_types.index(node_type)
                parent_ast_id = self.node_types.index(parent_node_type)
            except ValueError as err:
                raise ValueError(
                    f"AST types {type_info!r} of token "
                    f"{self.tokenizer.decode(token_id)!r} at offsets "
                    f"({start}, {end}) are not all in node_types"
                ) from err
            ast_ids.append(ast_id)
            parent_ast_ids.append(parent_ast_id)

        encoding["ast_ids"] = ast_ids
        encoding["parent_ast_ids"] = parent_ast_ids
        return encoding

    @classmethod
    def from_pretrained(
        cls,
        name_or_path: str,  # name or path of the tokenizer
        lang: str,          # language of the tokenizer
    ):                      # CodeTokenizer for the given language
        """Create a CodeTokenizer from a pretrained tokenizer for a given language.

        The tokenizer is loaded with `AutoTokenizer`. The tree-sitter grammar
        and node types for `lang` are read from the package's `grammars`
        directory.
        """
        tokenizer = AutoTokenizer.from_pretrained(name_or_path)

        # Grab the node types from the tree-sitter language
        grammars_dir = f"{CodeCheckList.__path__[0]}/grammars"
        language = Language(f"{grammars_dir}/tree-sitter-languages.so", lang)
        node_path = f"{grammars_dir}/tree-sitter-{lang}/src/node-types.json"
        with open(node_path, encoding="utf-8") as f:
            node_types = json.load(f)
        node_types = utils.unroll_node_types(node_types)

        # Create a parser for the language
        parser = Parser()
        parser.set_language(language)

        # `cls` rather than a hard-coded class so subclasses construct themselves.
        return cls(tokenizer, parser, node_types)

In [13]:
# test the tokenizer
py_tokenizer = CodeTokenizer.from_pretrained("gpt2", "python")
code = "def foo():\n    print('hello world')"

encoding = py_tokenizer(code)
n_tokens = len(encoding["input_ids"])
n_types = len(py_tokenizer.node_types)

# The AST annotations must exist and line up one-to-one with the tokens.
for key in ("ast_ids", "parent_ast_ids"):
    assert key in encoding
    assert len(encoding[key]) == n_tokens
    # Each id is either -1 (unaligned) or a valid index into node_types.
    assert all(-1 <= type_id < n_types for type_id in encoding[key])

# The leading `def` keyword lies inside an AST node, so it must be aligned.
assert encoding["ast_ids"][0] != -1
assert encoding["parent_ast_ids"][0] != -1

print(encoding)
print(py_tokenizer.tokenizer.convert_ids_to_tokens(encoding["input_ids"]))

{'input_ids': [4299, 22944, 33529, 198, 220, 220, 220, 3601, 10786, 31373, 995, 11537], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'offset_mapping': [(0, 3), (3, 7), (7, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 20), (20, 22), (22, 27), (27, 33), (33, 35)], 'ast_ids': [136, 99, 10, -1, -1, -1, -1, 99, 10, 89, 89, 89], 'parent_ast_ids': [69, 69, 115, -1, -1, -1, -1, 114, 101, 101, 101, 101]}
['def', 'Ġfoo', '():', 'Ċ', 'Ġ', 'Ġ', 'Ġ', 'Ġprint', "('", 'hello', 'Ġworld', "')"]
