In [1]:
from tree_sitter import Language, Parser
import javalang

Language.build_library(
	# Store the library in the `build` directory
	'build/my-languages.so',
	
	# Include one or more languages
	[
		'/Users/jirigesi/Documents/tree-sitter-java'
	]
)

JAVA_LANGUAGE = Language('build/my-languages.so', 'java')
parser = Parser()

parser.set_language(JAVA_LANGUAGE)

In [37]:
def traverse(node,depth=0):
    declaration = {}
    stack = []
    stack.append(node)
    while stack:
        node = stack.pop()
        if 'declaration' in node.type or 'if_statement' in node.type or 'else' in node.type:
            data = code[node.start_byte:node.end_byte].split('{')[0].strip().split(' ')
            if node.type in declaration:
                declaration[node.type].append(data)
            else:
                declaration[node.type] = [data]
        for child in node.children:
            stack.append(child)
    return declaration

def label_tokens(token_list, declaration):
    types = [] 
    for token in token_list:
        flag = False
        for key in declaration:
            for value in declaration[key]:
                if token in value:
                    types.append(key)
                    flag = True
                    break
            if flag:
                break
        if not flag:
            types.append("other")
    return types

In [38]:
def get_extended_types(token_list, types):
    tree = list(javalang.tokenizer.tokenize(" ".join(token_list)))
    code = ' '.join(token_list)
    right = 0
    left = 0

    postion_mapping = [] 

    while right < len(code):
        if code[right] == ' ':
            postion_mapping.append((left, right))
            left = right + 1
        right += 1

    # add the last token
    postion_mapping.append((left, right))
    
    extended_types = []
    left = 0

    for node in tree:
        node = str(node).split(' ')
        left = int(node[-1]) -1
        right = left + len(node[1][1:-1])
        # check (left, right) in postion_mapping and get the index
        for i in range(len(postion_mapping)):
            if left >= postion_mapping[i][0] and right <= postion_mapping[i][1]:
                extended_types.append([types[i], node[1]])
                break
    return extended_types

In [8]:
code = "class Simple{ public static void main(String args[]){ System.out.println( 'Hello Java'); }}"

code = code.replace("{", " {")
code = " ".join(code.split())
code_list = code.split(' ')

tree = parser.parse(bytes(code, "utf8"))

root_node = tree.root_node
# declaration = {}
declaration = traverse2(root_node)
print(declaration)
types = label_tokens(code_list, declaration)
print(types)

if len(types) != len(code_list):
    print("Error: the number of tokens is not equal to the number of labels")

ast_types = get_extended_types(code_list, types)

{'class_declaration': [['class', 'Simple']], 'method_declaration': [['public', 'static', 'void', 'main(String', 'args[])']]}
['class_declaration', 'class_declaration', 'other', 'method_declaration', 'method_declaration', 'method_declaration', 'method_declaration', 'method_declaration', 'other', 'other', 'other', 'other', 'other']


In [39]:
def get_ast_types(code):
    code = code.replace("{", " {")
    code = " ".join(code.split())
    code_list = code.split(' ')
    tree = parser.parse(bytes(code, "utf8"))
    root_node = tree.root_node
    declaration = traverse(root_node)
    types = label_tokens(code_list, declaration)

    ast_types = get_extended_types(code_list, types)
    return ast_types

code = "class Simple{ public static void main(String args[]){ System.out.println( 'Hello Java'); }}"
ast_types = get_ast_types(code)

In [40]:
ast_types

[['class_declaration', '"class"'],
 ['class_declaration', '"Simple"'],
 ['other', '"{"'],
 ['other', '"public"'],
 ['method_declaration', '"static"'],
 ['method_declaration', '"void"'],
 ['method_declaration', '"main"'],
 ['method_declaration', '"("'],
 ['method_declaration', '"String"'],
 ['method_declaration', '"args"'],
 ['method_declaration', '"["'],
 ['method_declaration', '"]"'],
 ['method_declaration', '")"'],
 ['other', '"{"'],
 ['other', '"System"'],
 ['other', '"."'],
 ['other', '"out"'],
 ['other', '"."'],
 ['other', '"println"'],
 ['other', '"("'],
 ['other', '"\'Hello'],
 ['other', '")"'],
 ['other', '";"'],
 ['other', '"}"'],
 ['other', '"}"']]

In [13]:
import json

file_path = "../dataset/valid.txt"
postfix=file_path.split('/')[-1].split('.txt')[0]
index_filename=file_path
url_to_code={}
with open('/'.join(index_filename.split('/')[:-1])+'/data.jsonl') as f:
    for line in f:
        line=line.strip()
        js=json.loads(line)
        url_to_code[js['idx']]=js['func']
data=[]
cache={}
f=open(index_filename)
with open(index_filename) as f:
    lines = 1000
    added_lines = 0
    for line in f:
        # control number of read data 
        if added_lines >= lines:
            break
        line=line.strip()
        url1,url2,label=line.split('\t')
        if url1 not in url_to_code or url2 not in url_to_code:
            continue
        if label=='0':
            label=0
        else:
            label=1
        data.append((url1,url2,label,' '.join(url_to_code[url1].split()), ' '.join(url_to_code[url2].split())))
        added_lines += 1

In [None]:
len(data)

1000

In [21]:
data[2]

('1141235',
 '14322332',
 0,
 'public static void main(String[] args) { String u = "http://portal.acm.org/results.cfm?query=%28Author%3A%22" + "Boehm%2C+Barry" + "%22%29&srt=score%20dsc&short=0&source_disp=&since_month=&since_year=&before_month=&before_year=&coll=ACM&dl=ACM&termshow=matchboolean&range_query=&CFID=22704101&CFTOKEN=37827144&start=1"; URL url = null; AcmSearchresultPageParser_2011Jan cb = new AcmSearchresultPageParser_2011Jan(); try { url = new URL(u); HttpURLConnection uc = (HttpURLConnection) url.openConnection(); uc.setUseCaches(false); InputStream is = uc.getInputStream(); BufferedReader br = new BufferedReader(new InputStreamReader(is)); ParserDelegator pd = new ParserDelegator(); pd.parse(br, cb, true); br.close(); } catch (MalformedURLException e) { e.printStackTrace(); } catch (IOException e) { e.printStackTrace(); } System.out.println("all doc num= " + cb.getAllDocNum()); for (int i = 0; i < cb.getEachResultStartPositions().size(); i++) { HashMap<String, Integer>

In [43]:
types = get_ast_types(data[5][3])

In [48]:
code = data[5][4]
code = code.replace("{", " {")
code = " ".join(code.split())
code_list = code.split(' ')
tree = parser.parse(bytes(code, "utf8"))
root_node = tree.root_node
declaration = traverse2(root_node)
types = label_tokens(code_list, declaration)

In [49]:
types

['local_variable_declaration',
 'local_variable_declaration',
 'other',
 'other',
 'local_variable_declaration',
 'other',
 'other',
 'local_variable_declaration',
 'local_variable_declaration',
 'local_variable_declaration',
 'local_variable_declaration',
 'local_variable_declaration',
 'local_variable_declaration',
 'local_variable_declaration',
 'local_variable_declaration',
 'other',
 'other',
 'local_variable_declaration',
 'local_variable_declaration',
 'other',
 'local_variable_declaration',
 'local_variable_declaration',
 'local_variable_declaration',
 'local_variable_declaration',
 'local_variable_declaration',
 'other',
 'other',
 'other',
 'local_variable_declaration',
 'local_variable_declaration',
 'local_variable_declaration',
 'local_variable_declaration',
 'local_variable_declaration',
 'local_variable_declaration',
 'other',
 'other',
 'other',
 'other',
 'if_statement',
 'if_statement',
 'other',
 'if_statement',
 'if_statement',
 'if_statement',
 'other',
 'if_statem

In [51]:
data[5][4]

'public Long processAddCompany(Company companyBean, Long holdingId) { PreparedStatement ps = null; DatabaseAdapter dbDyn = null; try { dbDyn = DatabaseAdapter.getInstance(); CustomSequenceType seq = new CustomSequenceType(); seq.setSequenceName("seq_WM_LIST_COMPANY"); seq.setTableName("WM_LIST_COMPANY"); seq.setColumnName("ID_FIRM"); Long sequenceValue = dbDyn.getSequenceNextValue(seq); ps = dbDyn.prepareStatement("insert into WM_LIST_COMPANY (" + " ID_FIRM, " + " full_name, " + " short_name, " + " address, " + " chief, " + " buh, " + " url, " + " short_info, " + " is_deleted" + ")values " + (dbDyn.getIsNeedUpdateBracket() ? "(" : "") + " ?," + " ?," + " ?," + " ?," + " ?," + " ?," + " ?," + " ?," + " 0 " + (dbDyn.getIsNeedUpdateBracket() ? ")" : "")); int num = 1; RsetTools.setLong(ps, num++, sequenceValue); ps.setString(num++, companyBean.getName()); ps.setString(num++, companyBean.getShortName()); ps.setString(num++, companyBean.getAddress()); ps.setString(num++, companyBean.getCeo(