# Core Module Tokenizer

In [1]:
#| default_exp tokenizer

In [2]:
#| export
import CodeSyntaxConcept
import CodeSyntaxConcept.utils as utils

import json
from transformers import AutoTokenizer
from tree_sitter import Language, Parser

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
#| hide
from nbdev.showdoc import *

In [4]:
#| export
def get_token_type(
    tok_span: tuple, # (start, end) position of a token in tokenizer
    nodes: list,     # list of tree-sitter nodes
    lines: list,     # list of lines in the code
) -> tuple: # (parent_type, token_type) of the token
    """Get the parent AST type and token AST type of a token."""
    def get_node_span(node):
        start_span = utils.convert_to_offset(node.start_point, lines)
        end_span = utils.convert_to_offset(node.end_point, lines)
        return start_span, end_span
    
    node_spans = [get_node_span(node) for node in nodes]
    for i, span in enumerate(node_spans):
        if (span[0] <= tok_span[0] and tok_span[0] < span[1]) or (span[0] < tok_span[1] and tok_span[1] <= span[1]):
            return nodes[i].parent.type, nodes[i].type

In [5]:
#| export
class CodeTokenizer():
    """A tokenizer for code, which aligns the tokens with the AST nodes."""
    def __init__(self, tokenizer, parser, node_types):
        self.tokenizer = tokenizer
        self.parser = parser
        self.node_types = node_types
    
    def __call__(self, code):
        encoding = self.tokenizer(code, return_offsets_mapping=True)
        tree = self.parser.parse(bytes(code, "utf8"))
        nodes = []
        utils.traverse(tree.root_node, nodes)

        encoding["ast_ids"] = []
        encoding["parent_ast_ids"] = []
        for i, (start, end) in enumerate(encoding.offset_mapping):
            if encoding["input_ids"][i] in self.tokenizer.all_special_ids:
                encoding["ast_ids"].append(-1)
                encoding["parent_ast_ids"].append(-1)
                continue
            if start == None or end == None:
                encoding["ast_ids"].append(-1)
                encoding["parent_ast_ids"].append(-1)
                continue
            type_info = get_token_type((start, end), nodes, code.split("\n"))
            if type_info is None:
                encoding["ast_ids"].append(-1)
                encoding["parent_ast_ids"].append(-1)
            else:
                parent_node_type, node_type = type_info
                try:
                    encoding["ast_ids"].append(self.node_types.index(node_type))
                    encoding["parent_ast_ids"].append(self.node_types.index(parent_node_type))
                except Exception as e:
                    print(type_info)
                    print(code)
                    print(self.tokenizer.decode(encoding["input_ids"][i]))
                    encoding["ast_ids"].append(-1)
                    encoding["parent_ast_ids"].append(-1)
                    raise e
            
        return encoding

    @staticmethod
    def from_pretrained(
        name_or_path: str,  # name or path of the tokenizer
        lang: str,          # language of the tokenizer
    ):                      # CodeTokenizer for the given language
        """Create a CodeTokenizer from a pretrained tokenizer for a given language."""
        tokenizer = AutoTokenizer.from_pretrained(name_or_path)

        # Grab the node types from the tree-sitter language
        language = Language(f"{CodeSyntaxConcept.__path__[0]}/grammars/tree-sitter-languages.so", lang)
        node_path = f"{CodeSyntaxConcept.__path__[0]}/grammars/tree-sitter-{lang}/src/node-types.json"
        with open(node_path) as f:
            node_types = json.load(f)
        node_types = utils.unroll_node_types(node_types)

        # Create a parser for the language
        parser = Parser()
        parser.set_language(language)
        
        return CodeTokenizer(tokenizer, parser, node_types)

# Testing

In [6]:
from CodeSyntaxConcept import loader

"""define language"""
python_language = "python"

languages = [python_language]

loader.download_grammars(languages)

/home/svelascodimate/miniconda3/envs/code-syntax-concept/lib/python3.10/site-packages/CodeSyntaxConcept/grammars


In [7]:
# test the tokenizer
py_tokenizer = CodeTokenizer.from_pretrained("gpt2", "python")
code = "def foo():\n    print('hello world')"
#code =  "def reserve_free_tcp_port(self):\n        \"\"\"\"\"\"\n        Reserve free TCP port to be used by Cloud SQL Proxy\n        \"\"\"\"\"\"\n        self.reserved_tcp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n        self.reserved_tcp_socket.bind(('127.0.0.1', 0))\n        self.sql_proxy_tcp_port = self.reserved_tcp_socket.getsockname()[1]"

#print(py_tokenizer.node_types)

encoding = py_tokenizer(code)

assert "ast_ids" in encoding
assert "parent_ast_ids" in encoding
assert len(encoding["ast_ids"]) == len(encoding["input_ids"])
assert len(encoding["parent_ast_ids"]) == len(encoding["input_ids"])

print(encoding)
print(py_tokenizer.tokenizer.convert_ids_to_tokens(encoding["input_ids"]))

('ERROR', 'identifier')
def reserve_free_tcp_port(self):
        """"""
        Reserve free TCP port to be used by Cloud SQL Proxy
        """"""
        self.reserved_tcp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.reserved_tcp_socket.bind(('127.0.0.1', 0))
        self.sql_proxy_tcp_port = self.reserved_tcp_socket.getsockname()[1]
 Reserve


ValueError: 'ERROR' is not in list