In [1]:
from tree_sitter import Language, Parser
import javalang

Language.build_library(
	# Store the library in the `build` directory
	'build/my-languages.so',
	
	# Include one or more languages
	[
		'/Users/jirigesi/Documents/tree-sitter-java'
	]
)

JAVA_LANGUAGE = Language('build/my-languages.so', 'java')
parser = Parser()

parser.set_language(JAVA_LANGUAGE)

example = "class Simple{ public static void main(String args[]){ System.out.println( 'Hello Java'); }}"

example = " ".join(example.split())
code = example.replace("{", " {").strip()
token_list = example.split(' ')

tree = parser.parse(bytes(example, "utf8"))

root_node = tree.root_node

In [2]:
declaration = {}

def traverse(node, depth=0):
    if 'declaration' in node.type:
        data = example[node.start_byte:node.end_byte].split('{')[0].strip().split(' ')
        
        if node.type in declaration:
            declaration[node.type].append(data)
        else:
            declaration[node.type] = [data]
    for child in node.children:
        traverse(child, depth + 1)

def label_tokens(token_list, declaration):
    types = [] 
    for token in token_list:
        flag = False
        for key in declaration:
            for value in declaration[key]:
                if token in value:
                    types.append(key)
                    flag = True
                    break
            if flag:
                break
        if not flag:
            types.append("other")
    return types

In [3]:
traverse(root_node)
types = label_tokens(token_list, declaration)

In [4]:
def get_position_mapping(code):
    right = 0
    left = 0
    postion_mapping = [] 
    while right < len(code):
        if code[right] == ' ':
            postion_mapping.append((left, right))
            left = right + 1
        right += 1
    # add the last token
    postion_mapping.append((left, right))
    return postion_mapping

In [5]:
postion_mapping = get_position_mapping(code)

In [6]:
code

"class Simple { public static void main(String args[]) { System.out.println( 'Hello Java'); }}"

In [7]:
len(code)

93

In [8]:
len(postion_mapping), len(token_list), len(types)

(13, 11, 11)

In [9]:
postion_mapping

[(0, 5),
 (6, 12),
 (13, 14),
 (15, 21),
 (22, 28),
 (29, 33),
 (34, 45),
 (46, 53),
 (54, 55),
 (56, 75),
 (76, 82),
 (83, 90),
 (91, 93)]

In [10]:
def get_extended_types(token_list, types):
    tree = list(javalang.tokenizer.tokenize(" ".join(token_list)))
    code = ' '.join(token_list)
    right = 0
    left = 0

    postion_mapping = [] 

    while right < len(code):
        if code[right] == ' ':
            postion_mapping.append((left, right))
            left = right + 1
        right += 1

    # add the last token
    postion_mapping.append((left, right))
    
    extended_types = []
    left = 0

    for node in tree:
        node = str(node).split(' ')
        left = int(node[-1]) -1
        right = left + len(node[1][1:-1])
        # check (left, right) in postion_mapping and get the index
        for i in range(len(postion_mapping)):
            if left >= postion_mapping[i][0] and right <= postion_mapping[i][1]:
                extended_types.append([types[i], node[1]])
                break
    return extended_types

In [11]:
extended_types = get_extended_types(token_list, types)

In [12]:
extended_types

[['class_declaration', '"class"'],
 ['other', '"Simple"'],
 ['other', '"{"'],
 ['method_declaration', '"public"'],
 ['method_declaration', '"static"'],
 ['method_declaration', '"void"'],
 ['method_declaration', '"main"'],
 ['method_declaration', '"("'],
 ['method_declaration', '"String"'],
 ['other', '"args"'],
 ['other', '"["'],
 ['other', '"]"'],
 ['other', '")"'],
 ['other', '"{"'],
 ['other', '"System"'],
 ['other', '"."'],
 ['other', '"out"'],
 ['other', '"."'],
 ['other', '"println"'],
 ['other', '"("'],
 ['other', '"\'Hello'],
 ['other', '")"'],
 ['other', '";"'],
 ['other', '"}"'],
 ['other', '"}"']]