In [22]:
from tree_sitter import Language, Parser
import javalang
import numpy as np 

Language.build_library(
	# Store the library in the `build` directory
	'build/my-languages.so',
	
	# Include one or more languages
	[
		'/Users/jirigesi/Documents/tree-sitter-java'
	]
)

JAVA_LANGUAGE = Language('build/my-languages.so', 'java')
parser = Parser()

parser.set_language(JAVA_LANGUAGE)

In [2]:
def traverse(node,depth=0):
    declaration = {}
    stack = []
    stack.append(node)
    while stack:
        node = stack.pop()
        if ('declaration' in node.type and node.type != "local_variable_declaration") or 'if_statement' in node.type or 'else' in node.type:
            data = code[node.start_byte:node.end_byte].split('{')[0].strip().split(' ')
            if node.type in declaration:
                declaration[node.type].append(data)
            else:
                declaration[node.type] = [data]
        for child in node.children:
            stack.append(child)
    return declaration

def label_tokens(token_list, declaration):
    types = [] 
    for token in token_list:
        flag = False
        for key in declaration:
            for value in declaration[key]:
                if token in value:
                    types.append(key)
                    flag = True
                    break
            if flag:
                break
        if not flag:
            types.append("other")
    return types

In [39]:
def get_extended_types(token_list, types):
    tree = list(javalang.tokenizer.tokenize(" ".join(token_list)))
    code = ' '.join(token_list)
    right = 0
    left = 0
    postion_mapping = [] 

    while right < len(code):
        if code[right] == ' ':
            postion_mapping.append((left, right))
            left = right + 1
        right += 1

    # add the last token
    postion_mapping.append((left, right))
    code = ["<s>"]
    extended_types = []
    left = 0
    for node in tree:
        # rewrite code
        node = str(node).split(' ')
        if node[1] == '"MASK"':
            code.append('<mask>')
        else:
            code.append(node[1][1:-1])
        # extend types
        left = int(node[-1]) -1
        right = left + len(node[1][1:-1])
        # check (left, right) in postion_mapping and get the index
        for i in range(len(postion_mapping)):
            if left >= postion_mapping[i][0] and right <= postion_mapping[i][1]:
                extended_types.append([types[i], node[1]])
                break
    code.append("</s>")
    return extended_types, ' '.join(code)

In [40]:
def get_ast_types(code):
    code = code.replace("{", " {")
    code = " ".join(code.split())
    code_list = code.split(' ')
    tree = parser.parse(bytes(code, "utf8"))
    root_node = tree.root_node
    declaration = traverse(root_node)
    types = label_tokens(code_list, declaration)

    ast_types, rewrote_code = get_extended_types(code_list, types)
    return ast_types, rewrote_code

code = "class Simple{ public static void main(String args[]){ System.out.println( 'Hello Java'); }}"
ast_types, rewrote_code = get_ast_types(code)

In [41]:
len(ast_types), len(rewrote_code.split(' '))

(25, 27)

In [5]:
# code = "class Simple{ public static void main(String args[]){ System.out.println( 'Hello Java'); }}"

# code = code.replace("{", " {")
# code = " ".join(code.split())
# code_list = code.split(' ')

# tree = parser.parse(bytes(code, "utf8"))

# root_node = tree.root_node
# # declaration = {}
# declaration = traverse(root_node)
# print(declaration)
# types = label_tokens(code_list, declaration)
# print(types)

# if len(types) != len(code_list):
#     print("Error: the number of tokens is not equal to the number of labels")

# ast_types = get_extended_types(code_list, types)

{'class_declaration': [['class', 'Simple']], 'method_declaration': [['public', 'static', 'void', 'main(String', 'args[])']]}
['class_declaration', 'class_declaration', 'other', 'method_declaration', 'method_declaration', 'method_declaration', 'method_declaration', 'method_declaration', 'other', 'other', 'other', 'other', 'other']


In [6]:
ast_types

[['class_declaration', '"class"'],
 ['class_declaration', '"Simple"'],
 ['other', '"{"'],
 ['method_declaration', '"public"'],
 ['method_declaration', '"static"'],
 ['method_declaration', '"void"'],
 ['method_declaration', '"main"'],
 ['method_declaration', '"("'],
 ['method_declaration', '"String"'],
 ['method_declaration', '"args"'],
 ['method_declaration', '"["'],
 ['method_declaration', '"]"'],
 ['method_declaration', '")"'],
 ['other', '"{"'],
 ['other', '"System"'],
 ['other', '"."'],
 ['other', '"out"'],
 ['other', '"."'],
 ['other', '"println"'],
 ['other', '"("'],
 ['other', '"\'Hello'],
 ['other', '")"'],
 ['other', '";"'],
 ['other', '"}"'],
 ['other', '"}"']]

In [7]:
import json

file_path = "../dataset/valid.txt"
postfix=file_path.split('/')[-1].split('.txt')[0]
index_filename=file_path
url_to_code={}
with open('/'.join(index_filename.split('/')[:-1])+'/data.jsonl') as f:
    for line in f:
        line=line.strip()
        js=json.loads(line)
        url_to_code[js['idx']]=js['func']
data=[]
cache={}
f=open(index_filename)
with open(index_filename) as f:
    lines = 1000
    added_lines = 0
    for line in f:
        # control number of read data 
        if added_lines >= lines:
            break
        line=line.strip()
        url1,url2,label=line.split('\t')
        if url1 not in url_to_code or url2 not in url_to_code:
            continue
        if label=='0':
            label=0
        else:
            label=1
        data.append((url1,url2,label,' '.join(url_to_code[url1].split()), ' '.join(url_to_code[url2].split())))
        added_lines += 1

In [8]:
len(data)

1000

In [19]:
def convert_types(types):
    # check the index of first second value is the "{"
    if types[0][1] == '"class"':
        return ['[CLS]'] + types + ['[SEP]']
    index_ = 0
    # if not class declaration, find the first "{" and add method_declaration before it
    for i in range(len(types)):
        if types[i][1] == '"{"':
            index_ = i
            break
    final_types = [] 
    final_types.append('[CLS]')
    for i in range(len(types)):
        if i < index_:
            final_types.append("method_declaration")
        else:
            final_types.append(types[i][0])
    final_types.append('[SEP]')

    return final_types

In [13]:
def get_syntax_types_for_code(code_snippet):
  types = ["[CLS]"]
  code = ["<s>"]
  tree = list(javalang.tokenizer.tokenize(code_snippet))
  
  for i in tree:
    j = str(i)
    j = j.split(" ")
    if j[1] == '"MASK"':
      types.append('[MASK]')
      code.append('<mask>')
    else:
      types.append(j[0].lower())
      code.append(j[1][1:-1])
    
  types.append("[SEP]")
  code.append("</s>")
  return np.array(types), ' '.join(code)

In [34]:
code = data[50][4]

for i in range(len(data)):
    code = data[i][3]
    types = get_ast_types(code)
    final_types = convert(types)
    types_1, rewrote_code_1 = get_syntax_types_for_code(code)

    if len(final_types) != len(types_1):
        print("Error")

In [35]:
final_types

['[CLS]',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'method_declaration',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 'other',
 '[SEP]']

In [36]:
# get unique types 
unique = []

for i in range(len(data)):
    code = data[i][3]
    types = get_ast_types(code)
    final_types = convert(types)

    for j in range(len(final_types)):
        if final_types[j] not in unique:
            unique.append(final_types[j])

In [37]:
unique

['[CLS]',
 'method_declaration',
 'other',
 '[SEP]',
 'if_statement',
 'else',
 'field_declaration',
 'class_declaration',
 'constructor_declaration']

In [None]:
ast_syntaxes = ['method_declaration', 'if_statement', 'else', 'class_declaration', 'constructor_declaration']

In [43]:
[1] + [2, 3, 4] + [5]

[1, 2, 3, 4, 5]