In [1]:
import pandas as pd
import javalang
from javalang.ast import Node
from tqdm import tqdm
from multiprocessing import Process, cpu_count, Manager, Pool 
import os
import torch
from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, DataCollatorWithPadding
from anytree import AnyNode
import json
from torch_geometric.data import Data

In [2]:
divide_node_num = 30
MAX_NODE_NUM = 450 # the max num of subgraph, set for zero padding 
max_subgraph_num = int(MAX_NODE_NUM/divide_node_num) 
max_source_length = 400

In [3]:
data_url = '/data/dataset/CodeXGLUE/Code-Code/Clone-detection-BigCloneBench/dataset/data.jsonl'
train_url =  '/data/dataset/CodeXGLUE/Code-Code/Clone-detection-BigCloneBench/dataset/train.txt'
valid_url = '/data/dataset/CodeXGLUE/Code-Code/Clone-detection-BigCloneBench/dataset/test.txt'
test_url = '/data/dataset/CodeXGLUE/Code-Code/Clone-detection-BigCloneBench/dataset/valid.txt'

In [4]:
data = pd.read_json(path_or_buf=data_url, lines=True)

In [5]:
java_api_url = '/data/code/represent-code-in-human/data/java_api.csv'
java_api = pd.read_csv(java_api_url, header=0, encoding='utf-8')
java_api['index_name'] = java_api['index_name'].apply(str)
java_api 

Unnamed: 0,index_name,index_description,method_description
0,a,Variable in class java.awt.AWTEventMulticaster,
1,A,Static variable in class java.awt.PageAttribut...,"The MediaType instance for Engineering A, 8 1/..."
2,A,Static variable in class javax.print.attribute...,"Specifies the engineering A size, 8.5 inch by ..."
3,A,Static variable in class javax.print.attribute...,A size .
4,A,Static variable in class javax.swing.text.html...,
...,...,...,...
51185,_write(OutputStream),Method in class org.omg.PortableInterceptor.IO...,
51186,_write(OutputStream),Method in class org.omg.PortableInterceptor.Ob...,
51187,_write(OutputStream),Method in class org.omg.PortableInterceptor.Ob...,
51188,_write(OutputStream),Method in class org.omg.PortableInterceptor.Ob...,


In [6]:
def get_token(node):
    token = ''
    if isinstance(node, str):
        token = node
    elif isinstance(node, set):
        token = 'Modifier'
    elif isinstance(node, Node):
        token = node.__class__.__name__
    return token


def get_child(root):
    if isinstance(root, Node):
        children = root.children
    elif isinstance(root, set):
        children = list(root)
    else:
        children = []

    def expand(nested_list):
        for item in nested_list:
            if isinstance(item, list):
                yield from expand(item)
            elif item:
                yield item

    return list(expand(children))


def get_sequence(node, sequence, api_sequence):
    token, children = get_token(node), get_child(node)
    sequence.append(token)
    if token == 'MethodInvocation':
        api = [get_token(child) for child in children if not get_child(child)]
        # api_sequence.append(' '.join(api))
        if len(api) > 1:
            api_sequence.append(api[-1])
    for child in children:
        get_sequence(child, sequence, api_sequence)

In [7]:
def api_match(api_sequence, java_api):
    description_sequence = []
    for api in api_sequence:
        loc = java_api.loc[java_api['index_name'].str.contains(api, case=True)]
        if not loc.empty:
            description = loc['method_description'].iloc[0]
            if description != 'None':
                description_sequence.append(description)
    return description_sequence

In [8]:
def parse_program(func):
    tokens = javalang.tokenizer.tokenize(func)
    parser = javalang.parser.Parser(tokens)
    tree = parser.parse_member_declaration()
    return tree

In [9]:
# multi-process
def multi_get_ast_and_des(l, i):
    sequence = []
    api_sequence = []    
    get_sequence(parse_program(data['func'].iloc[i]), sequence, api_sequence)
    ast = ' '.join(sequence)
    api_sequence = list(set(api_sequence)) 
    des = ' '.join(api_match(api_sequence, java_api))
    d = {'ast': ast, 'des': des, 'i': i}
    # print('d', d)
    l.append(d)


manager = Manager()
data_size = len(data)
print('data_size', data_size)

l = manager.list()
p = Pool(processes=20)
for i in range(data_size):
    p.apply_async(multi_get_ast_and_des, (l, i))
p.close()
p.join()

ast = []
des = []
i = []
for d in l[:]:
    ast.append(d['ast'].encode('utf-8','ignore').decode("utf-8"))
    des.append(d['des'].encode('utf-8','ignore').decode("utf-8"))
    i.append(d['i'])
d = {'ast': ast, 'des': des, 'i': i}
df = pd.DataFrame.from_dict(d)    
df

data_size 9126


Unnamed: 0,ast,des,i
0,MethodDeclaration Modifier public ReferenceTyp...,Creates an AlphaComposite object with the spec...,2
1,MethodDeclaration BasicType byte makeIDPFXORMa...,Creates an AlphaComposite object with the spec...,5
2,MethodDeclaration Modifier public Annotation O...,Sets the text for this label to the specified ...,4
3,MethodDeclaration Modifier public static copyF...,Returns the name of the component given the co...,6
4,MethodDeclaration Modifier public Annotation O...,"Retrieves the guarded object, or throws an exc...",7
...,...,...,...
9121,MethodDeclaration Modifier private verifyAvail...,Called by the context acceptor to process a to...,9121
9122,MethodDeclaration Modifier public ReferenceTyp...,Returns the context path of all the endpoints ...,9113
9123,MethodDeclaration search FormalParameter Refer...,Adds a mapping from a single String native to ...,9102
9124,MethodDeclaration Modifier public static copyR...,Returns the behavior when inserting characters...,9120


In [10]:
df = df.sort_values(by=['i']).reset_index(drop=True)
df

Unnamed: 0,ast,des,i
0,MethodDeclaration Modifier public static main ...,Returns the behavior when inserting characters...,0
1,MethodDeclaration Modifier synchronized public...,Creates an AlphaComposite object with the spec...,1
2,MethodDeclaration Modifier public ReferenceTyp...,Creates an AlphaComposite object with the spec...,2
3,MethodDeclaration Modifier public ReferenceTyp...,Adds a mapping from a single String native to ...,3
4,MethodDeclaration Modifier public Annotation O...,Sets the text for this label to the specified ...,4
...,...,...,...
9121,MethodDeclaration Modifier private verifyAvail...,Called by the context acceptor to process a to...,9121
9122,MethodDeclaration Modifier public static copyF...,Returns the behavior when inserting characters...,9122
9123,MethodDeclaration Modifier private ReferenceTy...,Creates an AlphaComposite object with the spec...,9123
9124,MethodDeclaration Modifier public ReferenceTyp...,Creates an AlphaComposite object with the spec...,9124


In [11]:
data['ast'] = df['ast'].to_list()
data['des'] = df['des'].to_list()
data['ast_des'] = data['ast'] + ' ' + data['des']
data

Unnamed: 0,func,idx,ast,des,ast_des
0,public static void main(String[] args) {\n...,10000832,MethodDeclaration Modifier public static main ...,Returns the behavior when inserting characters...,MethodDeclaration Modifier public static main ...
1,public synchronized String getSerialNumber...,10005623,MethodDeclaration Modifier synchronized public...,Creates an AlphaComposite object with the spec...,MethodDeclaration Modifier synchronized public...
2,public Object run() {\n ...,10005624,MethodDeclaration Modifier public ReferenceTyp...,Creates an AlphaComposite object with the spec...,MethodDeclaration Modifier public ReferenceTyp...
3,public String post() {\n if (conten...,10005674,MethodDeclaration Modifier public ReferenceTyp...,Adds a mapping from a single String native to ...,MethodDeclaration Modifier public ReferenceTyp...
4,@Override\n public void onCreate(Bundle...,10005879,MethodDeclaration Modifier public Annotation O...,Sets the text for this label to the specified ...,MethodDeclaration Modifier public Annotation O...
...,...,...,...,...,...
9121,private void verifyAvailability() {\n ...,9980885,MethodDeclaration Modifier private verifyAvail...,Called by the context acceptor to process a to...,MethodDeclaration Modifier private verifyAvail...
9122,public static void copyFiles(String strPat...,9983757,MethodDeclaration Modifier public static copyF...,Returns the behavior when inserting characters...,MethodDeclaration Modifier public static copyF...
9123,private String SHA1(String text) throws No...,9983984,MethodDeclaration Modifier private ReferenceTy...,Creates an AlphaComposite object with the spec...,MethodDeclaration Modifier private ReferenceTy...
9124,public String generateToken(String code) {...,9996334,MethodDeclaration Modifier public ReferenceTyp...,Creates an AlphaComposite object with the spec...,MethodDeclaration Modifier public ReferenceTyp...


In [12]:
data.to_json(path_or_buf='/data/dataset/CodeXGLUE/Code-Code/Clone-detection-BigCloneBench/dataset/data_enhanced.jsonl',
                     orient='records', lines=True)

In [13]:
data['ast_token'] = data['ast'].str.split()
data['des_token'] = data['des'].str.split()
data['ast_length'] = data['ast_token'].str.len()
data['des_length'] = data['des_token'].str.len()

In [14]:
data['ast_des_length'] = data['ast_length'] + data['des_length']

In [15]:
data.describe()

Unnamed: 0,idx,ast_length,des_length,ast_des_length
count,9126.0,9126.0,9126.0,9126.0
mean,11515330.0,251.20206,96.954197,348.156257
std,7113866.0,359.57828,67.242846,404.390309
min,74.0,9.0,0.0,9.0
25%,5308223.0,101.0,50.0,156.25
50%,11349860.0,168.0,83.0,249.0
75%,17895150.0,273.0,125.0,401.0
max,23677150.0,10978.0,781.0,11423.0


In [16]:
# use javalang to generate ASTs and depth-first traverse to generate ast nodes corpus
def get_token(node):
    token = 'None'
    if isinstance(node, str):
        token = node
    elif isinstance(node, set):
        token = 'Modifier'
    elif isinstance(node, Node):
        token = node.__class__.__name__
    return token


def get_child(root):
    if isinstance(root, Node):
        children = root.children
    elif isinstance(root, set):
        children = list(root)
    else:
        children = []

    def expand(nested_list):
        for item in nested_list:
            if isinstance(item, list):
                for sub_item in expand(item):
                    yield sub_item
            elif item:
                yield item

    return list(expand(children))


def get_sequence(node, sequence):
    token, children = get_token(node), get_child(node)
    sequence.append(token)
    for child in children:
        get_sequence(child, sequence)


def parse_program(func):
    tokens = javalang.tokenizer.tokenize(func)
    parser = javalang.parser.Parser(tokens)
    tree = parser.parse_member_declaration()
    return tree

In [17]:
checkpoint = 'microsoft/codebert-base'
tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
ast_tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
roberta = RobertaModel.from_pretrained(checkpoint)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
config = RobertaConfig.from_pretrained(checkpoint)
javalang_special_tokens = ['CompilationUnit','Import','Documented','Declaration','TypeDeclaration','PackageDeclaration',
                            'ClassDeclaration','EnumDeclaration','InterfaceDeclaration','AnnotationDeclaration','Type',
                            'BasicType','ReferenceType','TypeArgument','TypeParameter','Annotation','ElementValuePair',
                            'ElementArrayValue','Member','MethodDeclaration','FieldDeclaration','ConstructorDeclaration',
                            'ConstantDeclaration','ArrayInitializer','VariableDeclaration','LocalVariableDeclaration',
                            'VariableDeclarator','FormalParameter','InferredFormalParameter','Statement','IfStatement',
                            'WhileStatement','DoStatement','ForStatement','AssertStatement','BreakStatement','ContinueStatement',
                            'ReturnStatement','ThrowStatement','SynchronizedStatement','TryStatement','SwitchStatement',
                            'BlockStatement','StatementExpression','TryResource','CatchClause','CatchClauseParameter',
                            'SwitchStatementCase','ForControl','EnhancedForControl','Expression','Assignment','TernaryExpression',
                            'BinaryOperation','Cast','MethodReference','LambdaExpression','Primary','Literal','This',
                            'MemberReference','Invocation','ExplicitConstructorInvocation','SuperConstructorInvocation',
                            'MethodInvocation','SuperMethodInvocation','SuperMemberReference','ArraySelector','ClassReference',
                            'VoidClassReference','Creator','ArrayCreator','ClassCreator','InnerClassCreator','EnumBody',
                            'EnumConstantDeclaration','AnnotationMethod', 'Modifier']
special_tokens_dict = {'additional_special_tokens': javalang_special_tokens}
num_added_toks = ast_tokenizer.add_special_tokens(special_tokens_dict)

In [18]:
#  generate tree for AST Node
def create_tree(root, node, node_list, sub_id_list, leave_list, tokenizer, parent=None):
    id = len(node_list)
    node_list.append(node)
    token, children = get_token(node), get_child(node)

    if children == []:
        # print('this is a leaf:', token, id)
        leave_list.append(id)

    # Use roberta.tokenizer to generate subtokens
    # If a token can be divided into multiple(>1) subtokens, the first subtoken will be set as the previous node, 
    # and the other subtokens will be set as its new children
    token = token.encode('utf-8','ignore').decode("utf-8")   
    sub_token_list = tokenizer.tokenize(token)
    
    if id == 0:
        root.token = sub_token_list[0] # the root node is one of the tokenizer's special tokens
        root.data = node
        # record the num of nodes for every children of root
        root_children_node_num = []
        for child in children:
            node_num = len(node_list)
            create_tree(root, child, node_list, sub_id_list, leave_list, tokenizer, parent=root)
            root_children_node_num.append(len(node_list) - node_num)        
        return root_children_node_num
    else:
        # print(sub_token_list)
        new_node = AnyNode(id=id, token=sub_token_list[0], data=node, parent=parent)
        if len(sub_token_list) > 1:
            sub_id_list.append(id)
            for sub_token in sub_token_list[1:]:
                id += 1
                AnyNode(id=id, token=sub_token, data=node, parent=new_node)
                node_list.append(sub_token)
                sub_id_list.append(id)
        
        for child in children:
            create_tree(root, child, node_list, sub_id_list, leave_list, tokenizer, parent=new_node)
    # print(token, id)

In [19]:
# traverse the AST tree to get all the nodes and edges
def get_node_and_edge(node, node_index_list, tokenizer, src, tgt, variable_token_list, variable_id_list):
    token = node.token
    node_index_list.append(tokenizer.convert_tokens_to_ids(token))
    # node_index_list.append([vocab_dict.word2id.get(token, UNK)])
    # find out all variables
    if token in ['VariableDeclarator', 'MemberReference']:
        if node.children: # some chidren are comprised by non-utf8 and will be removed
            variable_token_list.append(node.children[0].token)
            variable_id_list.append(node.children[0].id)   
    
    for child in node.children:
        src.append(node.id)
        tgt.append(child.id)
        src.append(child.id)
        tgt.append(node.id)
        get_node_and_edge(child, node_index_list, tokenizer, src, tgt, variable_token_list, variable_id_list)

In [20]:
# generate pytorch_geometric input format data from ast
def get_pyg_data_from_ast(ast, tokenizer):
    node_list = []
    sub_id_list = [] # record the ids of node that can be divide into multple subtokens
    leave_list = [] # record the ids of leave 
    new_tree = AnyNode(id=0, token=None, data=None)
    root_children_node_num = create_tree(new_tree, ast, node_list, sub_id_list, leave_list, tokenizer)
    # print('root_children_node_num', root_children_node_num)
    x = []
    edge_src = []
    edge_tgt = []
    # record variable tokens and ids to add data flow edge in AST graph
    variable_token_list = []
    variable_id_list = []
    get_node_and_edge(new_tree, x, tokenizer, edge_src, edge_tgt, variable_token_list, variable_id_list)

    ast_edge_num = len(edge_src)
    edge_attr = [[0] for _ in range(ast_edge_num)]
    # set subtoken edge type to 2
    for i in range(len(edge_attr)):
        if edge_src[i] in sub_id_list and edge_tgt[i] in sub_id_list:
            edge_attr[i] = [2]
    # add data flow edge
    variable_dict = {}
    for i in range(len(variable_token_list)):
        # print('variable_dict', variable_dict)
        if variable_token_list[i] not in variable_dict:
            variable_dict.setdefault(variable_token_list[i], variable_id_list[i])
        else:
            # print('edge', variable_dict.get(variable_token_list[i]), variable_id_list[i])
            edge_src.append(variable_dict.get(variable_token_list[i]))
            edge_tgt.append(variable_id_list[i])
            edge_src.append(variable_id_list[i])
            edge_tgt.append(variable_dict.get(variable_token_list[i]))
            variable_dict[variable_token_list[i]] = variable_id_list[i]
    dataflow_edge_num = len(edge_src) - ast_edge_num

    # add next-token edge
    nexttoken_edge_num = len(leave_list)-1
    for i in range(nexttoken_edge_num):
        edge_src.append(leave_list[i])
        edge_tgt.append(leave_list[i+1])
        edge_src.append(leave_list[i+1])
        edge_tgt.append(leave_list[i])

    edge_index = [edge_src, edge_tgt]

    # set data flow edge type to 1
    for _ in range(dataflow_edge_num):
        edge_attr.append([1])
    
    # set data flow edge type to 3
    for _ in range(nexttoken_edge_num * 2):
        edge_attr.append([3])
    
    return x, edge_index, edge_attr, root_children_node_num

In [21]:

def get_subgraph_node_num(root_children_node_num, divide_node_num):
    subgraph_node_num = []
    node_sum = 0
    real_graph_num = 0
    for num in root_children_node_num:
        node_sum += num
        if node_sum >= divide_node_num:
            subgraph_node_num.append(node_sum)
            node_sum = 0    
    
    subgraph_node_num.append(node_sum)
    real_graph_num = len(subgraph_node_num)

    if real_graph_num >= max_subgraph_num:
        return subgraph_node_num[: max_subgraph_num], max_subgraph_num

    # print(len(subgraph_node_num))
    # if the last subgraph node num < divide_node_num, then put the last subgraph to the second to last subgraph
    # if subgraph_node_num[-1] < divide_node_num:
    #     subgraph_node_num[-2] = subgraph_node_num[-2] + subgraph_node_num[-1]
    #     subgraph_node_num[-1] = 0
    #     real_graph_num -= 1

    # zero padding for tensor transforming
    for _ in range(real_graph_num, max_subgraph_num):
        subgraph_node_num.append(0)
    
    return subgraph_node_num, real_graph_num

In [22]:
x_list = []
edge_index_list = []
edge_attr_list = []
subgraph_node_num_list = []
real_graph_num_list = []

for i in tqdm(range(len(data))):
    ast = parse_program(data['func'][i])
    x, edge_index, edge_attr, root_children_node_num = get_pyg_data_from_ast(ast, ast_tokenizer)
    subgraph_node_num, real_graph_num = get_subgraph_node_num(root_children_node_num, divide_node_num)
    x_list.append(x)
    edge_index_list.append(edge_index)
    edge_attr_list.append(edge_attr)
    subgraph_node_num_list.append(subgraph_node_num)
    real_graph_num_list.append(real_graph_num)

100%|██████████| 9126/9126 [09:06<00:00, 16.69it/s]


In [23]:
data['x'] = x_list
data['edge_index'] = edge_index_list
data['edge_attr'] = edge_attr_list
data['subgraph_node_num'] = subgraph_node_num_list
data['real_graph_num'] = real_graph_num_list

In [None]:
data['x_length'] = data['x'].str.len()
data.describe()

In [24]:
data = data.set_index('idx')
data

Unnamed: 0_level_0,func,ast,des,ast_des,ast_token,des_token,ast_length,des_length,ast_des_length,x,edge_index,edge_attr,subgraph_node_num,real_graph_num
idx,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
10000832,public static void main(String[] args) {\n...,MethodDeclaration Modifier public static main ...,Returns the behavior when inserting characters...,MethodDeclaration Modifier public static main ...,"[MethodDeclaration, Modifier, public, static, ...","[Returns, the, behavior, when, inserting, char...",1825,267,2092,"[50281, 50335, 15110, 42653, 17894, 50289, 502...","[[0, 1, 1, 2, 1, 3, 0, 4, 0, 5, 5, 6, 6, 7, 5,...","[[0], [0], [0], [0], [0], [0], [0], [0], [0], ...","[2123, 181, 141, 724, 453, 0, 0, 0, 0, 0, 0, 0...",6
10005623,public synchronized String getSerialNumber...,MethodDeclaration Modifier synchronized public...,Creates an AlphaComposite object with the spec...,MethodDeclaration Modifier synchronized public...,"[MethodDeclaration, Modifier, synchronized, pu...","[Creates, an, AlphaComposite, object, with, th...",152,92,244,"[50281, 50335, 38972, 30630, 1538, 15110, 5027...","[[0, 1, 1, 2, 2, 3, 2, 4, 1, 5, 0, 6, 6, 7, 0,...","[[0], [0], [0], [0], [2], [2], [2], [2], [0], ...","[35, 48, 123, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",4
10005624,public Object run() {\n ...,MethodDeclaration Modifier public ReferenceTyp...,Creates an AlphaComposite object with the spec...,MethodDeclaration Modifier public ReferenceTyp...,"[MethodDeclaration, Modifier, public, Referenc...","[Creates, an, AlphaComposite, object, with, th...",69,38,107,"[50281, 50335, 15110, 50275, 46674, 2962, 5030...","[[0, 1, 1, 2, 0, 3, 3, 4, 0, 5, 0, 6, 6, 7, 7,...","[[0], [0], [0], [0], [0], [0], [0], [0], [0], ...","[102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",2
10005674,public String post() {\n if (conten...,MethodDeclaration Modifier public ReferenceTyp...,Adds a mapping from a single String native to ...,MethodDeclaration Modifier public ReferenceTyp...,"[MethodDeclaration, Modifier, public, Referenc...","[Adds, a, mapping, from, a, single, String, na...",395,144,539,"[50281, 50335, 15110, 50275, 34222, 7049, 5029...","[[0, 1, 1, 2, 0, 3, 3, 4, 0, 5, 0, 6, 6, 7, 7,...","[[0], [0], [0], [0], [0], [0], [0], [0], [0], ...","[37, 574, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",3
10005879,@Override\n public void onCreate(Bundle...,MethodDeclaration Modifier public Annotation O...,Sets the text for this label to the specified ...,MethodDeclaration Modifier public Annotation O...,"[MethodDeclaration, Modifier, public, Annotati...","[Sets, the, text, for, this, label, to, the, s...",143,65,208,"[50281, 50335, 15110, 50278, 49116, 261, 44758...","[[0, 1, 1, 2, 0, 3, 3, 4, 0, 5, 5, 6, 0, 7, 7,...","[[0], [0], [0], [0], [0], [0], [0], [0], [0], ...","[34, 33, 122, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...",4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9980885,private void verifyAvailability() {\n ...,MethodDeclaration Modifier private verifyAvail...,Called by the context acceptor to process a to...,MethodDeclaration Modifier private verifyAvail...,"[MethodDeclaration, Modifier, private, verifyA...","[Called, by, the, context, acceptor, to, proce...",795,156,951,"[50281, 50335, 22891, 2802, 4591, 49054, 50294...","[[0, 1, 1, 2, 0, 3, 3, 4, 3, 5, 0, 6, 6, 7, 7,...","[[0], [0], [0], [0], [0], [0], [2], [2], [2], ...","[1005, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",2
9983757,public static void copyFiles(String strPat...,MethodDeclaration Modifier public static copyF...,Returns the behavior when inserting characters...,MethodDeclaration Modifier public static copyF...,"[MethodDeclaration, Modifier, public, static, ...","[Returns, the, behavior, when, inserting, char...",170,61,231,"[50281, 50335, 15110, 42653, 44273, 14824, 502...","[[0, 1, 1, 2, 1, 3, 0, 4, 4, 5, 0, 6, 6, 7, 7,...","[[0], [0], [0], [0], [0], [0], [0], [0], [2], ...","[41, 166, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",3
9983984,private String SHA1(String text) throws No...,MethodDeclaration Modifier private ReferenceTy...,Creates an AlphaComposite object with the spec...,MethodDeclaration Modifier private ReferenceTy...,"[MethodDeclaration, Modifier, private, Referen...","[Creates, an, AlphaComposite, object, with, th...",64,55,119,"[50281, 50335, 22891, 50275, 34222, 45004, 134...","[[0, 1, 1, 2, 0, 3, 3, 4, 0, 5, 5, 6, 0, 7, 7,...","[[0], [0], [0], [0], [0], [0], [0], [0], [0], ...","[44, 34, 22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",3
9996334,public String generateToken(String code) {...,MethodDeclaration Modifier public ReferenceTyp...,Creates an AlphaComposite object with the spec...,MethodDeclaration Modifier public ReferenceTyp...,"[MethodDeclaration, Modifier, public, Referenc...","[Creates, an, AlphaComposite, object, with, th...",52,48,100,"[50281, 50335, 15110, 50275, 34222, 20557, 877...","[[0, 1, 1, 2, 0, 3, 3, 4, 0, 5, 5, 6, 5, 7, 0,...","[[0], [0], [0], [0], [0], [0], [0], [0], [0], ...","[73, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",2


In [25]:
def read_ccd_pairs(url):
    data = []
    with open(url) as f:
        for line in f:
            line = line.strip()
            id1, id2, label = line.split('\t')
            label = 0 if label == '0' else 1
            data.append((int(id1), int(id2), label))
    return data

train_pairs = read_ccd_pairs(train_url)
valid_pairs = read_ccd_pairs(valid_url)
test_pairs = read_ccd_pairs(test_url)

In [26]:
class PairData(Data):
    def __init__(self, edge_index_s, edge_attr_s, x_s, source_ids_s, subgraph_node_num_s, real_graph_num_s,
                    edge_index_t, edge_attr_t, x_t, source_ids_t, subgraph_node_num_t, real_graph_num_t,label):
        super(PairData, self).__init__()
        self.edge_index_s = edge_index_s
        self.edge_attr_s = edge_attr_s
        self.x_s = x_s
        self.source_ids_s = source_ids_s
        self.subgraph_node_num_s = subgraph_node_num_s
        self.real_graph_num_s = real_graph_num_s

        self.edge_index_t = edge_index_t
        self.edge_attr_t = edge_attr_t
        self.x_t = x_t
        self.source_ids_t = source_ids_t
        self.subgraph_node_num_t = subgraph_node_num_t
        self.real_graph_num_t = real_graph_num_t

        self.label = label
    
    def __inc__(self, key, value):
        if key == 'edge_index_s':
            return self.x_s.size(0)
        if key == 'edge_index_t':
            return self.x_t.size(0)
        else:
            return super().__inc__(key, value)
    

In [27]:
def convert_examples_to_features(examples, tokenizer, data):
    features = []
    for example in tqdm(examples):        
        id1 = example[0]
        id2 = example[1]
        label = example[2]

        x1 = data['x'][id1]
        edge_index1 = data['edge_index'][id1]
        edge_attr1 = data['edge_attr'][id1]
        subgraph_node_num1 = data['subgraph_node_num'][id1]
        real_graph_num1 = data['real_graph_num'][id1]
        ast_des1 = tokenizer.tokenize(data['ast_des'][id1])[: max_source_length-2]
        ast_des1 = [tokenizer.cls_token] + ast_des1 + [tokenizer.sep_token]
        source_ids1 = tokenizer.convert_tokens_to_ids(ast_des1)
        padding_length = max_source_length - len(source_ids1)
        source_ids1 = source_ids1 + [tokenizer.pad_token_id] * padding_length

        
        x2 = data['x'][id2]
        edge_index2 = data['edge_index'][id2]
        edge_attr2 = data['edge_attr'][id2]
        subgraph_node_num2 = data['subgraph_node_num'][id2]
        real_graph_num2 = data['real_graph_num'][id2]
        ast_des2 = tokenizer.tokenize(data['ast_des'][id2])[: max_source_length-2]
        ast_des2 = [tokenizer.cls_token] + ast_des2 + [tokenizer.sep_token]
        source_ids2 = tokenizer.convert_tokens_to_ids(ast_des2)
        padding_length = max_source_length - len(source_ids2)
        source_ids2 = source_ids2 + [tokenizer.pad_token_id] * padding_length

        if data['ast_des_length'][id1] < 600 and data['ast_des_length'][id2] < 600:
            features.append(
                PairData(
                    x_s= torch.tensor(x1, dtype=torch.long),
                    edge_index_s=torch.tensor(edge_index1, dtype=torch.long),
                    edge_attr_s=torch.tensor(edge_attr1, dtype=torch.long),
                    source_ids_s=torch.tensor(source_ids1, dtype=torch.long),
                    subgraph_node_num_s=torch.tensor(subgraph_node_num1, dtype=torch.long),
                    real_graph_num_s=torch.tensor(real_graph_num1, dtype=torch.long), 

                    x_t= torch.tensor(x2, dtype=torch.long),
                    edge_index_t=torch.tensor(edge_index2, dtype=torch.long),
                    edge_attr_t=torch.tensor(edge_attr2, dtype=torch.long),
                    source_ids_t=torch.tensor(source_ids2, dtype=torch.long),
                    subgraph_node_num_t=torch.tensor(subgraph_node_num2, dtype=torch.long),
                    real_graph_num_t=torch.tensor(real_graph_num2, dtype=torch.long), 

                    label = torch.tensor(label, dtype=torch.long)
                )
            )
    return features

In [28]:
train_features = convert_examples_to_features(train_pairs, tokenizer, data)
valid_features = convert_examples_to_features(valid_pairs, tokenizer, data)
test_features = convert_examples_to_features(test_pairs, tokenizer, data)

100%|██████████| 901028/901028 [1:14:04<00:00, 202.74it/s]
100%|██████████| 415416/415416 [30:57<00:00, 223.65it/s]
100%|██████████| 415416/415416 [32:41<00:00, 211.74it/s]


In [29]:
torch.save(train_features,'features/bcb-raw/train_features.pt')
torch.save(valid_features,'features/bcb-raw/valid_features.pt')
torch.save(test_features,'features/bcb-raw/test_features.pt')

In [30]:
len(train_features)

834647

In [31]:
len(valid_features)

390286

In [32]:
len(test_features)

382375