In [1]:
import pandas as pd
import javalang
from javalang.ast import Node
from tqdm import tqdm
from multiprocessing import Process, cpu_count, Manager, Pool 
import os
import torch
from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, DataCollatorWithPadding
from anytree import AnyNode
import json
from torch_geometric.data import Data

In [2]:
divide_node_num = 30
MAX_NODE_NUM = 450 # the max num of subgraph, set for zero padding 
max_subgraph_num = int(MAX_NODE_NUM/divide_node_num) 
max_source_length = 400

In [3]:
data_url = '/data/pycharm/BCB2015/BCB_F/function_out_split/data.jsonl'
train_url = '/data/pycharm/BCB2015/BCB_F/function_out_split/train.txt'
valid_url = '/data/pycharm/BCB2015/BCB_F/function_out_split/test.txt'
test_url = '/data/pycharm/BCB2015/BCB_F/function_out_split/valid.txt'

In [4]:
data = pd.read_json(path_or_buf=data_url, lines=True)

In [5]:
java_api_url = '/data/code/represent-code-in-human/data/java_api.csv'
java_api = pd.read_csv(java_api_url, header=0, encoding='utf-8')
java_api['index_name'] = java_api['index_name'].apply(str)
java_api 

Unnamed: 0,index_name,index_description,method_description
0,a,Variable in class java.awt.AWTEventMulticaster,
1,A,Static variable in class java.awt.PageAttribut...,"The MediaType instance for Engineering A, 8 1/..."
2,A,Static variable in class javax.print.attribute...,"Specifies the engineering A size, 8.5 inch by ..."
3,A,Static variable in class javax.print.attribute...,A size .
4,A,Static variable in class javax.swing.text.html...,
...,...,...,...
51185,_write(OutputStream),Method in class org.omg.PortableInterceptor.IO...,
51186,_write(OutputStream),Method in class org.omg.PortableInterceptor.Ob...,
51187,_write(OutputStream),Method in class org.omg.PortableInterceptor.Ob...,
51188,_write(OutputStream),Method in class org.omg.PortableInterceptor.Ob...,


In [6]:
def get_token(node):
    token = ''
    if isinstance(node, str):
        token = node
    elif isinstance(node, set):
        token = 'Modifier'
    elif isinstance(node, Node):
        token = node.__class__.__name__
    return token


def get_child(root):
    if isinstance(root, Node):
        children = root.children
    elif isinstance(root, set):
        children = list(root)
    else:
        children = []

    def expand(nested_list):
        for item in nested_list:
            if isinstance(item, list):
                yield from expand(item)
            elif item:
                yield item

    return list(expand(children))


def get_sequence(node, sequence, api_sequence):
    token, children = get_token(node), get_child(node)
    sequence.append(token)
    if token == 'MethodInvocation':
        api = [get_token(child) for child in children if not get_child(child)]
        # api_sequence.append(' '.join(api))
        if len(api) > 1:
            api_sequence.append(api[-1])
    for child in children:
        get_sequence(child, sequence, api_sequence)

In [7]:
def api_match(api_sequence, java_api):
    description_sequence = []
    for api in api_sequence:
        loc = java_api.loc[java_api['index_name'].str.contains(api, case=True)]
        if not loc.empty:
            description = loc['method_description'].iloc[0]
            if description != 'None':
                description_sequence.append(description)
    return description_sequence

In [8]:
def parse_program(func):
    tokens = javalang.tokenizer.tokenize(func)
    parser = javalang.parser.Parser(tokens)
    tree = parser.parse_member_declaration()
    return tree

In [9]:
# multi-process
def multi_get_ast_and_des(l, i):
    # print('run2')
    sequence = []
    api_sequence = []    
    # print('sequence', sequence)
    get_sequence(parse_program(data['func'].iloc[i]), sequence, api_sequence)
    # print('sequence', sequence)
    ast = ' '.join(sequence)
    api_sequence = list(set(api_sequence)) 
    des = ' '.join(api_match(api_sequence, java_api)) 
    d = {'ast': ast, 'des': des, 'i': i}
    # print('d', d)
    l.append(d)


manager = Manager()
data_size = len(data)
# print('data_size', data_size)
l = manager.list()
p = Pool(processes=30)
for i in range(data_size):
    # print('run1')
    p.apply_async(multi_get_ast_and_des, (l, i))
p.close()
p.join()

ast = []
des = []
i = []
for d in l[:]:
    # print('i', i)
    ast.append(d['ast'].encode('utf-8','ignore').decode("utf-8"))
    des.append(d['des'].encode('utf-8','ignore').decode("utf-8"))
    i.append(d['i'])
d = {'ast': ast, 'des': des, 'i': i}
df = pd.DataFrame.from_dict(d)    
df

Unnamed: 0,ast,des,i
0,ConstructorDeclaration ResultHelperTests Excep...,Creates a new instance of a DocumentBuilder us...,1
1,MethodDeclaration Modifier static copy FormalP...,Returns the behavior when inserting characters...,0
2,MethodDeclaration Modifier public shuttlesort ...,,8
3,MethodDeclaration Modifier public convert Form...,Terminates the current line by writing the lin...,3
4,MethodDeclaration Modifier public ReferenceTyp...,Returns the model that this button represents....,2
...,...,...,...
73296,MethodDeclaration Modifier public static main ...,Sets the subject criterion. Creates an AlphaCo...,73282
73297,MethodDeclaration Modifier public static Refer...,Returns the hardware address (usually MAC) of ...,73291
73298,MethodDeclaration Modifier public static execu...,"Causes the current thread to wait, if necessar...",73284
73299,MethodDeclaration Modifier private okButtonAct...,If passed to the appropriate variant of java....,73196


In [10]:
df = df.sort_values(by=['i']).reset_index(drop=True)
df

Unnamed: 0,ast,des,i
0,MethodDeclaration Modifier static copy FormalP...,Returns the behavior when inserting characters...,0
1,ConstructorDeclaration ResultHelperTests Excep...,Creates a new instance of a DocumentBuilder us...,1
2,MethodDeclaration Modifier public ReferenceTyp...,Returns the model that this button represents....,2
3,MethodDeclaration Modifier public convert Form...,Terminates the current line by writing the lin...,3
4,MethodDeclaration Modifier public static main ...,Extract the appropriate property value from th...,4
...,...,...,...
73296,MethodDeclaration Modifier public static Refer...,Compiles the given regular expression into a p...,73296
73297,MethodDeclaration Modifier public static copyD...,Adds the RoleUnresolved specified as the last ...,73297
73298,MethodDeclaration Modifier public static copyD...,Adds the RoleUnresolved specified as the last ...,73298
73299,MethodDeclaration Modifier public static Basic...,Returns the current length of the sequence. Re...,73299


In [11]:
data['ast'] = df['ast'].to_list()
data['des'] = df['des'].to_list()
data['ast_des'] = data['ast'] + ' ' + data['des']
data

Unnamed: 0,idx,func,ast,des,ast_des
0,74,"static void copy (String src, String dest) thr...",MethodDeclaration Modifier static copy FormalP...,Returns the behavior when inserting characters...,MethodDeclaration Modifier static copy FormalP...
1,525,ResultHelperTests () throws Exception {\n N...,ConstructorDeclaration ResultHelperTests Excep...,Creates a new instance of a DocumentBuilder us...,ConstructorDeclaration ResultHelperTests Excep...
2,587,"public HTMLDocument handleURL (String suburl, ...",MethodDeclaration Modifier public ReferenceTyp...,Returns the model that this button represents....,MethodDeclaration Modifier public ReferenceTyp...
3,661,"public void convert (File src, File dest) thro...",MethodDeclaration Modifier public convert Form...,Terminates the current line by writing the lin...,MethodDeclaration Modifier public convert Form...
4,778,public static void main (String [] args) {\n ...,MethodDeclaration Modifier public static main ...,Extract the appropriate property value from th...,MethodDeclaration Modifier public static main ...
...,...,...,...,...,...
73296,23677221,public static List < String > extractMatches (...,MethodDeclaration Modifier public static Refer...,Compiles the given regular expression into a p...,MethodDeclaration Modifier public static Refer...
73297,23677222,"public static void copyDirectory1 (Path src, P...",MethodDeclaration Modifier public static copyD...,Adds the RoleUnresolved specified as the last ...,MethodDeclaration Modifier public static copyD...
73298,23677223,public static void copyDirectory2 (final Path ...,MethodDeclaration Modifier public static copyD...,Adds the RoleUnresolved specified as the last ...,MethodDeclaration Modifier public static copyD...
73299,23677226,public static boolean isPalindrome (String ori...,MethodDeclaration Modifier public static Basic...,Returns the current length of the sequence. Re...,MethodDeclaration Modifier public static Basic...


In [12]:
data.to_json(path_or_buf='/data/pycharm/BCB2015/BCB_F/function_out_split/data_enhanced.jsonl',
                     orient='records', lines=True)

In [13]:
data['ast_token'] = data['ast'].str.split()
data['des_token'] = data['des'].str.split()
data['ast_length'] = data['ast_token'].str.len()
data['des_length'] = data['des_token'].str.len()

In [14]:
data['ast_des_length'] = data['ast_length'] + data['des_length']

In [15]:
data.describe()

Unnamed: 0,idx,ast_length,des_length,ast_des_length
count,73301.0,73301.0,73301.0,73301.0
mean,10922570.0,233.262288,64.87404,298.136328
std,7158256.0,418.806626,66.196633,455.417573
min,74.0,4.0,0.0,4.0
25%,4428407.0,68.0,17.0,97.0
50%,10704550.0,133.0,48.0,188.0
75%,17116120.0,254.0,92.0,347.0
max,23677230.0,19230.0,1233.0,19251.0


In [16]:
# use javalang to generate ASTs and depth-first traverse to generate ast nodes corpus
def get_token(node):
    token = 'None'
    if isinstance(node, str):
        token = node
    elif isinstance(node, set):
        token = 'Modifier'
    elif isinstance(node, Node):
        token = node.__class__.__name__
    return token


def get_child(root):
    if isinstance(root, Node):
        children = root.children
    elif isinstance(root, set):
        children = list(root)
    else:
        children = []

    def expand(nested_list):
        for item in nested_list:
            if isinstance(item, list):
                for sub_item in expand(item):
                    yield sub_item
            elif item:
                yield item

    return list(expand(children))


def get_sequence(node, sequence):
    token, children = get_token(node), get_child(node)
    sequence.append(token)
    for child in children:
        get_sequence(child, sequence)


def parse_program(func):
    tokens = javalang.tokenizer.tokenize(func)
    parser = javalang.parser.Parser(tokens)
    tree = parser.parse_member_declaration()
    return tree

In [17]:
checkpoint = 'microsoft/codebert-base'
tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
ast_tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
roberta = RobertaModel.from_pretrained(checkpoint)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
config = RobertaConfig.from_pretrained(checkpoint)
javalang_special_tokens = ['CompilationUnit','Import','Documented','Declaration','TypeDeclaration','PackageDeclaration',
                            'ClassDeclaration','EnumDeclaration','InterfaceDeclaration','AnnotationDeclaration','Type',
                            'BasicType','ReferenceType','TypeArgument','TypeParameter','Annotation','ElementValuePair',
                            'ElementArrayValue','Member','MethodDeclaration','FieldDeclaration','ConstructorDeclaration',
                            'ConstantDeclaration','ArrayInitializer','VariableDeclaration','LocalVariableDeclaration',
                            'VariableDeclarator','FormalParameter','InferredFormalParameter','Statement','IfStatement',
                            'WhileStatement','DoStatement','ForStatement','AssertStatement','BreakStatement','ContinueStatement',
                            'ReturnStatement','ThrowStatement','SynchronizedStatement','TryStatement','SwitchStatement',
                            'BlockStatement','StatementExpression','TryResource','CatchClause','CatchClauseParameter',
                            'SwitchStatementCase','ForControl','EnhancedForControl','Expression','Assignment','TernaryExpression',
                            'BinaryOperation','Cast','MethodReference','LambdaExpression','Primary','Literal','This',
                            'MemberReference','Invocation','ExplicitConstructorInvocation','SuperConstructorInvocation',
                            'MethodInvocation','SuperMethodInvocation','SuperMemberReference','ArraySelector','ClassReference',
                            'VoidClassReference','Creator','ArrayCreator','ClassCreator','InnerClassCreator','EnumBody',
                            'EnumConstantDeclaration','AnnotationMethod', 'Modifier']
special_tokens_dict = {'additional_special_tokens': javalang_special_tokens}
num_added_toks = ast_tokenizer.add_special_tokens(special_tokens_dict)

In [18]:
#  generate tree for AST Node
def create_tree(root, node, node_list, sub_id_list, leave_list, tokenizer, parent=None):
    id = len(node_list)
    node_list.append(node)
    token, children = get_token(node), get_child(node)

    if children == []:
        # print('this is a leaf:', token, id)
        leave_list.append(id)

    # Use roberta.tokenizer to generate subtokens
    # If a token can be divided into multiple(>1) subtokens, the first subtoken will be set as the previous node, 
    # and the other subtokens will be set as its new children
    token = token.encode('utf-8','ignore').decode("utf-8")   
    sub_token_list = tokenizer.tokenize(token)
    
    if id == 0:
        root.token = sub_token_list[0] # the root node is one of the tokenizer's special tokens
        root.data = node
        # record the num of nodes for every children of root
        root_children_node_num = []
        for child in children:
            node_num = len(node_list)
            create_tree(root, child, node_list, sub_id_list, leave_list, tokenizer, parent=root)
            root_children_node_num.append(len(node_list) - node_num)        
        return root_children_node_num
    else:
        # print(sub_token_list)
        new_node = AnyNode(id=id, token=sub_token_list[0], data=node, parent=parent)
        if len(sub_token_list) > 1:
            sub_id_list.append(id)
            for sub_token in sub_token_list[1:]:
                id += 1
                AnyNode(id=id, token=sub_token, data=node, parent=new_node)
                node_list.append(sub_token)
                sub_id_list.append(id)
        
        for child in children:
            create_tree(root, child, node_list, sub_id_list, leave_list, tokenizer, parent=new_node)
    # print(token, id)

In [19]:
# traverse the AST tree to get all the nodes and edges
def get_node_and_edge(node, node_index_list, tokenizer, src, tgt, variable_token_list, variable_id_list):
    token = node.token
    node_index_list.append(tokenizer.convert_tokens_to_ids(token))
    # node_index_list.append([vocab_dict.word2id.get(token, UNK)])
    # find out all variables
    if token in ['VariableDeclarator', 'MemberReference']:
        if node.children: # some chidren are comprised by non-utf8 and will be removed
            variable_token_list.append(node.children[0].token)
            variable_id_list.append(node.children[0].id)   
    
    for child in node.children:
        src.append(node.id)
        tgt.append(child.id)
        src.append(child.id)
        tgt.append(node.id)
        get_node_and_edge(child, node_index_list, tokenizer, src, tgt, variable_token_list, variable_id_list)

In [20]:
# generate pytorch_geometric input format data from ast
def get_pyg_data_from_ast(ast, tokenizer):
    node_list = []
    sub_id_list = [] # record the ids of node that can be divide into multple subtokens
    leave_list = [] # record the ids of leave 
    new_tree = AnyNode(id=0, token=None, data=None)
    root_children_node_num = create_tree(new_tree, ast, node_list, sub_id_list, leave_list, tokenizer)
    # print('root_children_node_num', root_children_node_num)
    x = []
    edge_src = []
    edge_tgt = []
    # record variable tokens and ids to add data flow edge in AST graph
    variable_token_list = []
    variable_id_list = []
    get_node_and_edge(new_tree, x, tokenizer, edge_src, edge_tgt, variable_token_list, variable_id_list)

    ast_edge_num = len(edge_src)
    edge_attr = [[0] for _ in range(ast_edge_num)]
    # set subtoken edge type to 2
    for i in range(len(edge_attr)):
        if edge_src[i] in sub_id_list and edge_tgt[i] in sub_id_list:
            edge_attr[i] = [2]
    # add data flow edge
    variable_dict = {}
    for i in range(len(variable_token_list)):
        # print('variable_dict', variable_dict)
        if variable_token_list[i] not in variable_dict:
            variable_dict.setdefault(variable_token_list[i], variable_id_list[i])
        else:
            # print('edge', variable_dict.get(variable_token_list[i]), variable_id_list[i])
            edge_src.append(variable_dict.get(variable_token_list[i]))
            edge_tgt.append(variable_id_list[i])
            edge_src.append(variable_id_list[i])
            edge_tgt.append(variable_dict.get(variable_token_list[i]))
            variable_dict[variable_token_list[i]] = variable_id_list[i]
    dataflow_edge_num = len(edge_src) - ast_edge_num

    # add next-token edge
    nexttoken_edge_num = len(leave_list)-1
    for i in range(nexttoken_edge_num):
        edge_src.append(leave_list[i])
        edge_tgt.append(leave_list[i+1])
        edge_src.append(leave_list[i+1])
        edge_tgt.append(leave_list[i])

    edge_index = [edge_src, edge_tgt]

    # set data flow edge type to 1
    for _ in range(dataflow_edge_num):
        edge_attr.append([1])
    
    # set data flow edge type to 3
    for _ in range(nexttoken_edge_num * 2):
        edge_attr.append([3])
    
    return x, edge_index, edge_attr, root_children_node_num

In [21]:

def get_subgraph_node_num(root_children_node_num, divide_node_num):
    subgraph_node_num = []
    node_sum = 0
    real_graph_num = 0
    for num in root_children_node_num:
        node_sum += num
        if node_sum >= divide_node_num:
            subgraph_node_num.append(node_sum)
            node_sum = 0    
    
    subgraph_node_num.append(node_sum)
    real_graph_num = len(subgraph_node_num)

    if real_graph_num >= max_subgraph_num:
        return subgraph_node_num[: max_subgraph_num], max_subgraph_num

    # print(len(subgraph_node_num))
    # if the last subgraph node num < divide_node_num, then put the last subgraph to the second to last subgraph
    # if subgraph_node_num[-1] < divide_node_num:
    #     subgraph_node_num[-2] = subgraph_node_num[-2] + subgraph_node_num[-1]
    #     subgraph_node_num[-1] = 0
    #     real_graph_num -= 1

    # zero padding for tensor transforming
    for _ in range(real_graph_num, max_subgraph_num):
        subgraph_node_num.append(0)
    
    return subgraph_node_num, real_graph_num

In [22]:
x_list = []
edge_index_list = []
edge_attr_list = []
subgraph_node_num_list = []
real_graph_num_list = []

for i in tqdm(range(len(data))):
    ast = parse_program(data['func'][i])
    x, edge_index, edge_attr, root_children_node_num = get_pyg_data_from_ast(ast, ast_tokenizer)
    subgraph_node_num, real_graph_num = get_subgraph_node_num(root_children_node_num, divide_node_num)
    x_list.append(x)
    edge_index_list.append(edge_index)
    edge_attr_list.append(edge_attr)
    subgraph_node_num_list.append(subgraph_node_num)
    real_graph_num_list.append(real_graph_num)

100%|██████████| 73301/73301 [1:12:47<00:00, 16.78it/s]


In [23]:
data['x'] = x_list
data['edge_index'] = edge_index_list
data['edge_attr'] = edge_attr_list
data['subgraph_node_num'] = subgraph_node_num_list
data['real_graph_num'] = real_graph_num_list

In [None]:
data['x_length'] = data['x'].str.len()
data.describe()

In [24]:
data = data.set_index('idx')
data

Unnamed: 0_level_0,func,ast,des,ast_des,ast_token,des_token,ast_length,des_length,ast_des_length,x,edge_index,edge_attr,subgraph_node_num,real_graph_num
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
74,"static void copy (String src, String dest) thr...",MethodDeclaration Modifier static copy FormalP...,Returns the behavior when inserting characters...,MethodDeclaration Modifier static copy FormalP...,"[MethodDeclaration, Modifier, static, copy, Fo...","[Returns, the, behavior, when, inserting, char...",127,32,159,"[50281, 50335, 42653, 44273, 50289, 50275, 342...","[[0, 1, 1, 2, 0, 3, 0, 4, 4, 5, 5, 6, 4, 7, 0,...","[[0], [0], [0], [0], [0], [0], [0], [0], [0], ...","[35, 33, 32, 44, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0...",5
525,ResultHelperTests () throws Exception {\n N...,ConstructorDeclaration ResultHelperTests Excep...,Creates a new instance of a DocumentBuilder us...,ConstructorDeclaration ResultHelperTests Excep...,"[ConstructorDeclaration, ResultHelperTests, Ex...","[Creates, a, new, instance, of, a, DocumentBui...",511,127,638,"[50283, 48136, 48825, 565, 10092, 48847, 50287...","[[0, 1, 1, 2, 1, 3, 1, 4, 0, 5, 0, 6, 6, 7, 7,...","[[0], [0], [2], [2], [2], [2], [2], [2], [0], ...","[40, 43, 42, 30, 46, 44, 43, 51, 40, 53, 54, 5...",15
587,"public HTMLDocument handleURL (String suburl, ...",MethodDeclaration Modifier public ReferenceTyp...,Returns the model that this button represents....,MethodDeclaration Modifier public ReferenceTyp...,"[MethodDeclaration, Modifier, public, Referenc...","[Returns, the, model, that, this, button, repr...",248,83,331,"[50281, 50335, 15110, 50275, 48085, 47088, 266...","[[0, 1, 1, 2, 0, 3, 3, 4, 4, 5, 0, 6, 6, 7, 0,...","[[0], [0], [0], [0], [0], [0], [0], [0], [2], ...","[39, 33, 134, 34, 30, 40, 44, 4, 0, 0, 0, 0, 0...",8
661,"public void convert (File src, File dest) thro...",MethodDeclaration Modifier public convert Form...,Terminates the current line by writing the lin...,MethodDeclaration Modifier public convert Form...,"[MethodDeclaration, Modifier, public, convert,...","[Terminates, the, current, line, by, writing, ...",686,83,769,"[50281, 50335, 15110, 3865, 9942, 50289, 50275...","[[0, 1, 1, 2, 0, 3, 3, 4, 0, 5, 5, 6, 6, 7, 5,...","[[0], [0], [0], [0], [0], [0], [2], [2], [0], ...","[33, 42, 848, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",4
778,public static void main (String [] args) {\n ...,MethodDeclaration Modifier public static main ...,Extract the appropriate property value from th...,MethodDeclaration Modifier public static main ...,"[MethodDeclaration, Modifier, public, static, ...","[Extract, the, appropriate, property, value, f...",112,48,160,"[50281, 50335, 15110, 42653, 17894, 50289, 502...","[[0, 1, 1, 2, 1, 3, 0, 4, 0, 5, 5, 6, 6, 7, 5,...","[[0], [0], [0], [0], [0], [0], [0], [0], [0], ...","[140, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
23677221,public static List < String > extractMatches (...,MethodDeclaration Modifier public static Refer...,Compiles the given regular expression into a p...,MethodDeclaration Modifier public static Refer...,"[MethodDeclaration, Modifier, public, static, ...","[Compiles, the, given, regular, expression, in...",75,75,150,"[50281, 50335, 15110, 42653, 50275, 36583, 502...","[[0, 1, 1, 2, 1, 3, 0, 4, 4, 5, 4, 6, 6, 7, 7,...","[[0], [0], [0], [0], [0], [0], [0], [0], [0], ...","[37, 47, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",3
23677222,"public static void copyDirectory1 (Path src, P...",MethodDeclaration Modifier public static copyD...,Adds the RoleUnresolved specified as the last ...,MethodDeclaration Modifier public static copyD...,"[MethodDeclaration, Modifier, public, static, ...","[Adds, the, RoleUnresolved, specified, as, the...",69,51,120,"[50281, 50335, 15110, 42653, 44273, 49226, 134...","[[0, 1, 1, 2, 1, 3, 0, 4, 4, 5, 4, 6, 0, 7, 7,...","[[0], [0], [0], [0], [0], [0], [0], [0], [2], ...","[81, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",2
23677223,public static void copyDirectory2 (final Path ...,MethodDeclaration Modifier public static copyD...,Adds the RoleUnresolved specified as the last ...,MethodDeclaration Modifier public static copyD...,"[MethodDeclaration, Modifier, public, static, ...","[Adds, the, RoleUnresolved, specified, as, the...",100,43,143,"[50281, 50335, 15110, 42653, 44273, 49226, 176...","[[0, 1, 1, 2, 1, 3, 0, 4, 4, 5, 4, 6, 0, 7, 7,...","[[0], [0], [0], [0], [0], [0], [0], [0], [2], ...","[139, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",2
23677226,public static boolean isPalindrome (String ori...,MethodDeclaration Modifier public static Basic...,Returns the current length of the sequence. Re...,MethodDeclaration Modifier public static Basic...,"[MethodDeclaration, Modifier, public, static, ...","[Returns, the, current, length, of, the, seque...",74,25,99,"[50281, 50335, 15110, 42653, 50274, 3983, 4854...","[[0, 1, 1, 2, 1, 3, 0, 4, 4, 5, 5, 6, 0, 7, 7,...","[[0], [0], [0], [0], [0], [0], [0], [0], [0], ...","[67, 13, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",2


In [25]:
def read_ccd_pairs(url):
    data = []
    with open(url) as f:
        for line in f:
            line = line.strip()
            id1, id2, label = line.split('\t')
            label = 0 if label == '0' else 1
            data.append((int(id1), int(id2), label))
    return data

train_pairs = read_ccd_pairs(train_url)
valid_pairs = read_ccd_pairs(valid_url)
test_pairs = read_ccd_pairs(test_url)

In [26]:
class PairData(Data):
    def __init__(self, edge_index_s, edge_attr_s, x_s, source_ids_s, subgraph_node_num_s, real_graph_num_s,
                    edge_index_t, edge_attr_t, x_t, source_ids_t, subgraph_node_num_t, real_graph_num_t,label):
        super(PairData, self).__init__()
        self.edge_index_s = edge_index_s
        self.edge_attr_s = edge_attr_s
        self.x_s = x_s
        self.source_ids_s = source_ids_s
        self.subgraph_node_num_s = subgraph_node_num_s
        self.real_graph_num_s = real_graph_num_s

        self.edge_index_t = edge_index_t
        self.edge_attr_t = edge_attr_t
        self.x_t = x_t
        self.source_ids_t = source_ids_t
        self.subgraph_node_num_t = subgraph_node_num_t
        self.real_graph_num_t = real_graph_num_t

        self.label = label
    
    def __inc__(self, key, value):
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
        else:
            return super().__inc__(key, value)
    

In [27]:
def convert_examples_to_features(examples, tokenizer, data):
    features = []
    for example in tqdm(examples):        
        id1 = example[0]
        id2 = example[1]
        label = example[2]

        x1 = data['x'][id1]
        edge_index1 = data['edge_index'][id1]
        edge_attr1 = data['edge_attr'][id1]
        subgraph_node_num1 = data['subgraph_node_num'][id1]
        real_graph_num1 = data['real_graph_num'][id1]
        ast_des1 = tokenizer.tokenize(data['ast_des'][id1])[: max_source_length-2]
        ast_des1 = [tokenizer.cls_token] + ast_des1 + [tokenizer.sep_token]
        source_ids1 = tokenizer.convert_tokens_to_ids(ast_des1)
        padding_length = max_source_length - len(source_ids1)
        source_ids1 = source_ids1 + [tokenizer.pad_token_id] * padding_length

        
        x2 = data['x'][id2]
        edge_index2 = data['edge_index'][id2]
        edge_attr2 = data['edge_attr'][id2]
        subgraph_node_num2 = data['subgraph_node_num'][id2]
        real_graph_num2 = data['real_graph_num'][id2]
        ast_des2 = tokenizer.tokenize(data['ast_des'][id2])[: max_source_length-2]
        ast_des2 = [tokenizer.cls_token] + ast_des2 + [tokenizer.sep_token]
        source_ids2 = tokenizer.convert_tokens_to_ids(ast_des2)
        padding_length = max_source_length - len(source_ids2)
        source_ids2 = source_ids2 + [tokenizer.pad_token_id] * padding_length

        if data['ast_des_length'][id1] < 600 and data['ast_des_length'][id2] < 600:
            features.append(
                PairData(
                    x_s= torch.tensor(x1, dtype=torch.long),
                    edge_index_s=torch.tensor(edge_index1, dtype=torch.long),
                    edge_attr_s=torch.tensor(edge_attr1, dtype=torch.long),
                    source_ids_s=torch.tensor(source_ids1, dtype=torch.long),
                    subgraph_node_num_s=torch.tensor(subgraph_node_num1, dtype=torch.long),
                    real_graph_num_s=torch.tensor(real_graph_num1, dtype=torch.long), 

                    x_t= torch.tensor(x2, dtype=torch.long),
                    edge_index_t=torch.tensor(edge_index2, dtype=torch.long),
                    edge_attr_t=torch.tensor(edge_attr2, dtype=torch.long),
                    source_ids_t=torch.tensor(source_ids2, dtype=torch.long),
                    subgraph_node_num_t=torch.tensor(subgraph_node_num2, dtype=torch.long),
                    real_graph_num_t=torch.tensor(real_graph_num2, dtype=torch.long), 

                    label = torch.tensor(label, dtype=torch.long)
                )
            )
    return features

In [28]:
train_features = convert_examples_to_features(train_pairs, tokenizer, data)
valid_features = convert_examples_to_features(valid_pairs, tokenizer, data)
test_features = convert_examples_to_features(test_pairs, tokenizer, data)

100%|██████████| 394668/394668 [24:16<00:00, 270.88it/s]
100%|██████████| 82316/82316 [06:10<00:00, 221.98it/s]
100%|██████████| 80930/80930 [05:27<00:00, 247.16it/s]


In [29]:
# torch.save(train_features,'features/bcb-f/train_features.pt')
# torch.save(valid_features,'features/bcb-f/valid_features.pt')
# torch.save(test_features,'features/bcb-f/test_features.pt')

In [30]:
len(train_features)

372048

In [31]:
len(valid_features)

76928

In [32]:
len(test_features)

74425