In [1]:
import pandas as pd
import javalang
from javalang.ast import Node
from tqdm import tqdm
from multiprocessing import Process, cpu_count, Manager, Pool 
import os
import torch
from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, DataCollatorWithPadding
from anytree import AnyNode
import json
from torch_geometric.data import Data

In [2]:
divide_node_num = 30
MAX_NODE_NUM = 300 # the max num of subgraph, set for zero padding 
max_subgraph_num = int(MAX_NODE_NUM/divide_node_num) 
max_source_length = 256
max_target_length = 32

In [3]:
train_url = '/data/code/represent-code-in-human/data/TLC-SUM/train.json'
valid_url = '/data/code/represent-code-in-human/data/TLC-SUM/valid.json'
test_url = '/data/code/represent-code-in-human/data/TLC-SUM/test.json'

In [4]:
train_data = pd.read_json(path_or_buf=train_url, lines=True)
valid_data = pd.read_json(path_or_buf=valid_url, lines=True)
test_data = pd.read_json(path_or_buf=test_url, lines=True)

In [5]:
train_data['docstring_tokens'] = train_data['comment'].str.split()
valid_data['docstring_tokens'] = valid_data['comment'].str.split()
test_data['docstring_tokens'] = test_data['comment'].str.split()

In [6]:
train_data

Unnamed: 0,api_seq,comment,code,id,docstring_tokens
0,[Arrays.asList],Runs a command on the command line synchronou...,@Override public int runCommand(boolean mergeE...,4270,"[Runs, a, command, on, the, command, line, syn..."
1,"[String.length, String.length, String.startsWi...",Find Price List Version and update context,private int findPLV(int M_PriceList_ID){\n Ti...,65622,"[Find, Price, List, Version, and, update, cont..."
2,[Runtime.totalMemory],Returns true if less then 5% of the available...,public static boolean memoryIsLow(){\n return...,27884,"[Returns, true, if, less, then, 5%, of, the, a..."
3,"[StringBuilder.append, StringBuilder.append, S...",Returns a string representation of the object...,public String describeAttributes(){\n StringB...,53496,"[Returns, a, string, representation, of, the, ..."
4,[SecureRandom.nextBytes],Fill the given buffer with random bytes.,public static byte[] nextBytes(byte[] buffer){...,12260,"[Fill, the, given, buffer, with, random, bytes.]"
...,...,...,...,...,...
69703,[List.add],Generic Test SOAP Service,"public static Map<String,Object> testSOAPServi...",79340,"[Generic, Test, SOAP, Service]"
69704,[ResultSet.next],Writes the entire ResultSet to a CSV file. Th...,"public void writeAll(ResultSet rs,boolean incl...",85763,"[Writes, the, entire, ResultSet, to, a, CSV, f..."
69705,[System.setProperty],"Workaround for bug pre-Froyo, see here for mo...",public static void disableConnectionReuseIfNec...,51974,"[Workaround, for, bug, pre-Froyo,, see, here, ..."
69706,"[JComboBox.getSource, Integer.parseInt, JTextF...","When faction or subfaction is changed, refres...",private void updateRatingChoice(){\n int curr...,70363,"[When, faction, or, subfaction, is, changed,, ..."


In [7]:
len(valid_data)

8714

In [8]:
len(test_data)

8714

Delete the codes that cannot be parsed by javalang

In [9]:
def parse_program(func):
    tokens = javalang.tokenizer.tokenize(func)
    parser = javalang.parser.Parser(tokens)
    tree = parser.parse_member_declaration()
    return tree

In [10]:
def get_syntax_error_ids(data):
    syntax_error_ids = []
    for i in tqdm(range(len(data['code']))):
        try:
            tree = parse_program(data['code'][i])
        except:
            syntax_error_ids.append(i)
    return syntax_error_ids

In [11]:
train_syntax_error_ids = get_syntax_error_ids(train_data)
len(train_syntax_error_ids)

 61%|██████    | 42432/69708 [01:09<00:41, 655.15it/s]

In [None]:
valid_syntax_error_ids = get_syntax_error_ids(valid_data)
len(valid_syntax_error_ids)

100%|██████████| 8714/8714 [00:15<00:00, 552.14it/s]


0

In [None]:
test_syntax_error_ids = get_syntax_error_ids(test_data)
len(test_syntax_error_ids)

100%|██████████| 8714/8714 [00:15<00:00, 580.59it/s]


0

In [None]:
train_data_new = train_data.drop(train_syntax_error_ids)
valid_data_new = valid_data.drop(valid_syntax_error_ids)
test_data_new = test_data.drop(test_syntax_error_ids)

In [None]:
train_data_new = train_data_new.sample(frac=1).reset_index(drop=True)
valid_data_new = valid_data_new.sample(frac=1).reset_index(drop=True)
test_data_new = test_data_new.sample(frac=1).reset_index(drop=True)

In [None]:
train_data_new

Unnamed: 0,api_seq,comment,code,id,docstring_tokens
0,"[Iterator.hasNext, Pattern.matches, Iterator.n...",checks whether a specific value is container ...,public boolean containsValue(Object value){\n ...,41427,"[checks, whether, a, specific, value, is, cont..."
1,"[String.replaceAll, Pattern.compile, Pattern.c...",Convert a character literal into a character.,protected static Character toChar(String value...,40344,"[Convert, a, character, literal, into, a, char..."
2,[Math.log],compute the Shannon-Weaver diversity index in...,public static String computeShannonWeaver(View...,20715,"[compute, the, Shannon-Weaver, diversity, inde..."
3,[AtomicBoolean.set],This function should always be called under a...,private int reconcilePutPermits(){\n putPermi...,23760,"[This, function, should, always, be, called, u..."
4,"[Math.min, Math.max]",Updates the values range.,private void updateRange(double value){\n mMi...,81925,"[Updates, the, values, range.]"
...,...,...,...,...,...
69703,[Locale.getDefault],Return the standard presentation of this diag...,@Override public String toString(){\n return ...,84370,"[Return, the, standard, presentation, of, this..."
69704,[String.getTitle],Shares an episode to the Android app of choice,public static void shareEpisode(Context contex...,70005,"[Shares, an, episode, to, the, Android, app, o..."
69705,"[List.get, BigDecimal.getBasePrice, BigDecimal...",Add an item to the shopping cart.,"public int addItemToEnd(String productId,BigDe...",26198,"[Add, an, item, to, the, shopping, cart.]"
69706,[EnumSet.allOf],Find the _Fields constant that matches fieldI...,public static _Fields findByThriftId(int field...,78264,"[Find, the, _Fields, constant, that, matches, ..."


enhance code by ast and description

In [None]:
java_api_url = '/data/code/represent-code-in-human/data/java_api.csv'
java_api = pd.read_csv(java_api_url, header=0, encoding='utf-8')
java_api['index_name'] = java_api['index_name'].apply(str)
java_api 

Unnamed: 0,index_name,index_description,method_description
0,a,Variable in class java.awt.AWTEventMulticaster,
1,A,Static variable in class java.awt.PageAttribut...,"The MediaType instance for Engineering A, 8 1/..."
2,A,Static variable in class javax.print.attribute...,"Specifies the engineering A size, 8.5 inch by ..."
3,A,Static variable in class javax.print.attribute...,A size .
4,A,Static variable in class javax.swing.text.html...,
...,...,...,...
51185,_write(OutputStream),Method in class org.omg.PortableInterceptor.IO...,
51186,_write(OutputStream),Method in class org.omg.PortableInterceptor.Ob...,
51187,_write(OutputStream),Method in class org.omg.PortableInterceptor.Ob...,
51188,_write(OutputStream),Method in class org.omg.PortableInterceptor.Ob...,


In [None]:
def get_token(node):
    token = ''
    if isinstance(node, str):
        token = node
    elif isinstance(node, set):
        token = 'Modifier'
    elif isinstance(node, Node):
        token = node.__class__.__name__
    return token


def get_child(root):
    if isinstance(root, Node):
        children = root.children
    elif isinstance(root, set):
        children = list(root)
    else:
        children = []

    def expand(nested_list):
        for item in nested_list:
            if isinstance(item, list):
                yield from expand(item)
            elif item:
                yield item

    return list(expand(children))


def get_sequence(node, sequence, api_sequence):
    token, children = get_token(node), get_child(node)
    sequence.append(token)
    if token == 'MethodInvocation':
        api = [get_token(child) for child in children if not get_child(child)]
        # api_sequence.append(' '.join(api))
        if len(api) > 1:
            api_sequence.append(api[-1])
    for child in children:
        get_sequence(child, sequence, api_sequence)

In [None]:
def api_match(api_sequence, java_api):
    description_sequence = []
    for api in api_sequence:
        loc = java_api.loc[java_api['index_name'].str.contains(api, case=True)]
        if not loc.empty:
            description = loc['method_description'].iloc[0]
            if description != 'None':
                description_sequence.append(description)
    return description_sequence

In [None]:
def get_ast_and_description(data):
    description_sequence = []
    ast_sequence = []
    ast_sum = 0
    description_sum = 0
    data_size = len(data)
    for i in tqdm(range(data_size)):
        sequence = []
        api_sequence = []    
        get_sequence(parse_program(data['code'].iloc[i]), sequence, api_sequence)
        ast = ' '.join(sequence)
        ast_sequence.append(ast) 
        ast_sum += len(ast.split(' '))

        api_sequence = list(set(api_sequence)) 
        description = ' '.join(api_match(api_sequence, java_api)) 
        description_sequence.append(description) 
        description_sum += len(description.split(' '))
    print('ast average length', ast_sum/data_size)
    print('description average length', description_sum/data_size)
    return description_sequence, ast_sequence   

In [None]:
# multi-process
data_new = train_data_new
def multi_get_ast_and_des(l, i):
    sequence = []
    api_sequence = []    
    get_sequence(parse_program(data_new['code'].iloc[i]), sequence, api_sequence)
    ast = ' '.join(sequence)
    api_sequence = list(set(api_sequence)) 
    des = ' '.join(api_match(api_sequence, java_api)) 
    d = {'ast': ast, 'des': des, 'i': i}
    l.append(d)


manager = Manager()
data_size = len(data_new)
# print('data_size', data_size)
l = manager.list()
p = Pool(processes=30)
for i in range(data_size):
    p.apply_async(multi_get_ast_and_des, (l, i))
p.close()
p.join()

ast = []
des = []
i = []
for d in l[:]:
    ast.append(d['ast'].encode('utf-8','ignore').decode("utf-8"))
    des.append(d['des'].encode('utf-8','ignore').decode("utf-8"))
    i.append(d['i'])
d = {'ast': ast, 'des': des, 'i': i}
train_df = pd.DataFrame.from_dict(d)    
train_df

Unnamed: 0,ast,des,i
0,MethodDeclaration Modifier static public Refer...,,2
1,MethodDeclaration Modifier private BasicType i...,,3
2,MethodDeclaration Modifier public BasicType bo...,Returns true if this RenderingHints maps one o...,0
3,MethodDeclaration Modifier static protected Re...,Returns the current length of the sequence. Re...,1
4,ConstructorDeclaration EndWordAction FormalPar...,,13
...,...,...,...
69703,MethodDeclaration Modifier public Annotation O...,Atomically adds the given value to the current...,69660
69704,MethodDeclaration Modifier private ReferenceTy...,If true the component paints every pixel withi...,69656
69705,MethodDeclaration Modifier public BasicType in...,"Hides the splash screen, closes the window, an...",69702
69706,MethodDeclaration Modifier public startElement...,Writes a line separator. Returns the behavior ...,69685


In [None]:
# multi-process
data_new = test_data_new
def multi_get_ast_and_des(l, i):
    sequence = []
    api_sequence = []    
    get_sequence(parse_program(data_new['code'].iloc[i]), sequence, api_sequence)
    ast = ' '.join(sequence)
    api_sequence = list(set(api_sequence)) 
    des = ' '.join(api_match(api_sequence, java_api)) 
    d = {'ast': ast, 'des': des, 'i': i}
    l.append(d)


manager = Manager()
data_size = len(data_new)
# print('data_size', data_size)
l = manager.list()
p = Pool(processes=30)
for i in range(data_size):
    p.apply_async(multi_get_ast_and_des, (l, i))
p.close()
p.join()

ast = []
des = []
i = []
for d in l[:]:
    ast.append(d['ast'].encode('utf-8','ignore').decode("utf-8"))
    des.append(d['des'].encode('utf-8','ignore').decode("utf-8"))
    i.append(d['i'])
d = {'ast': ast, 'des': des, 'i': i}
test_df = pd.DataFrame.from_dict(d)    
test_df

Unnamed: 0,ast,des,i
0,MethodDeclaration Modifier static public Refer...,,2
1,MethodDeclaration Modifier private BasicType i...,,3
2,ConstructorDeclaration EndWordAction FormalPar...,,13
3,MethodDeclaration Modifier static public Refer...,Returns a string resulting from replacing all ...,32
4,MethodDeclaration Modifier public BasicType bo...,Checks whether the specified point is within t...,24
...,...,...,...
8709,MethodDeclaration Modifier private keysSetTest...,add Collection to set of Children (Unsupported...,8604
8710,MethodDeclaration Modifier static public ccDra...,Called by the context acceptor to process a to...,8696
8711,MethodDeclaration Modifier public put FormalPa...,Inserts the specified element at the end of th...,8635
8712,MethodDeclaration Modifier public draw FormalP...,Returns the coordinates and type of the curren...,8713


In [None]:
# multi-process
data_new = valid_data_new
def multi_get_ast_and_des(l, i):
    sequence = []
    api_sequence = []    
    get_sequence(parse_program(data_new['code'].iloc[i]), sequence, api_sequence)
    ast = ' '.join(sequence)
    api_sequence = list(set(api_sequence)) 
    des = ' '.join(api_match(api_sequence, java_api)) 
    d = {'ast': ast, 'des': des, 'i': i}
    l.append(d)


manager = Manager()
data_size = len(data_new)
# print('data_size', data_size)
l = manager.list()
p = Pool(processes=30)
for i in range(data_size):
    p.apply_async(multi_get_ast_and_des, (l, i))
p.close()
p.join()

ast = []
des = []
i = []
for d in l[:]:
    ast.append(d['ast'].encode('utf-8','ignore').decode("utf-8"))
    des.append(d['des'].encode('utf-8','ignore').decode("utf-8"))
    i.append(d['i'])
d = {'ast': ast, 'des': des , 'i': i}
valid_df = pd.DataFrame.from_dict(d)    
valid_df

Unnamed: 0,ast,des,i
0,MethodDeclaration Modifier public ReferenceTyp...,,0
1,MethodDeclaration Modifier static BasicType bo...,Gets the names of MBeans controlled by the MBe...,1
2,ConstructorDeclaration Modifier public WebServ...,,23
3,MethodDeclaration Modifier public overrideCurr...,Returns the current time in milliseconds.,14
4,MethodDeclaration Modifier private static init...,Called by the context acceptor to process a to...,9
...,...,...,...
8709,MethodDeclaration Modifier public ReferenceTyp...,Returns the object name of the MBean that cau...,8645
8710,MethodDeclaration Modifier public BasicType bo...,"Log a CONFIG message. Hides the splash screen,...",8617
8711,MethodDeclaration Modifier private BasicType i...,Visits an enum value in an annotation. Returns...,8713
8712,MethodDeclaration Modifier private _init Forma...,Creates a new instance of the class represente...,8710


In [None]:
train_df = train_df.sort_values(by=['i']).reset_index(drop=True)
train_data_new['ast'] = train_df['ast'].to_list()
train_data_new['des'] = train_df['des'].to_list()
train_data_new['ast_des'] = train_data_new['ast'] + ' ' + train_data_new['des']
train_data_new

In [None]:
valid_df = valid_df.sort_values(by=['i']).reset_index(drop=True)
valid_data_new['ast'] = valid_df['ast'].to_list()
valid_data_new['des'] = valid_df['des'].to_list()
valid_data_new['ast_des'] = valid_data_new['ast'] + ' ' + valid_data_new['des']
valid_data_new

In [None]:
test_df = test_df.sort_values(by=['i']).reset_index(drop=True)
test_data_new['ast'] = test_df['ast'].to_list()
test_data_new['des'] = test_df['des'].to_list()
test_data_new['ast_des'] = test_data_new['ast'] + ' ' + test_data_new['des']
test_data_new

In [None]:
valid_data_new.to_json(path_or_buf='/data/code/represent-code-in-human/data/TLC-SUM-enhanced/valid.jsonl',
                     orient='records', lines=True)

In [None]:
test_data_new.to_json(path_or_buf='/data/code/represent-code-in-human/data/TLC-SUM-enhanced/test.jsonl',
                     orient='records', lines=True)

In [None]:
train_data_new.to_json(path_or_buf='/data/code/represent-code-in-human/data/TLC-SUM-enhanced/train.jsonl',
                     orient='records', lines=True)


statistics

In [None]:
def statistics(data):
    ast_length = []
    for i in tqdm(range(len(data))):
        ast_length.append(len(data['ast'][i].split()))
    series = pd.Series(ast_length)
    print(series.describe())   

In [None]:
statistics(train_data_new)

100%|██████████| 69708/69708 [00:01<00:00, 50547.31it/s]


In [None]:
statistics(valid_data_new)

100%|██████████| 8714/8714 [00:00<00:00, 36333.09it/s]


In [None]:
statistics(test_data_new)

100%|██████████| 8714/8714 [00:00<00:00, 36396.95it/s]


write features

In [None]:
# use javalang to generate ASTs and depth-first traverse to generate ast nodes corpus
def get_token(node):
    token = 'None'
    if isinstance(node, str):
        token = node
    elif isinstance(node, set):
        token = 'Modifier'
    elif isinstance(node, Node):
        token = node.__class__.__name__
    return token


def get_child(root):
    if isinstance(root, Node):
        children = root.children
    elif isinstance(root, set):
        children = list(root)
    else:
        children = []

    def expand(nested_list):
        for item in nested_list:
            if isinstance(item, list):
                for sub_item in expand(item):
                    yield sub_item
            elif item:
                yield item

    return list(expand(children))


def get_sequence(node, sequence):
    token, children = get_token(node), get_child(node)
    sequence.append(token)
    for child in children:
        get_sequence(child, sequence)


def parse_program(func):
    tokens = javalang.tokenizer.tokenize(func)
    parser = javalang.parser.Parser(tokens)
    tree = parser.parse_member_declaration()
    return tree

In [None]:
checkpoint = 'microsoft/codebert-base'
tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
ast_tokenizer = RobertaTokenizer.from_pretrained(checkpoint)
roberta = RobertaModel.from_pretrained(checkpoint)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
config = RobertaConfig.from_pretrained(checkpoint)
javalang_special_tokens = ['CompilationUnit','Import','Documented','Declaration','TypeDeclaration','PackageDeclaration',
                            'ClassDeclaration','EnumDeclaration','InterfaceDeclaration','AnnotationDeclaration','Type',
                            'BasicType','ReferenceType','TypeArgument','TypeParameter','Annotation','ElementValuePair',
                            'ElementArrayValue','Member','MethodDeclaration','FieldDeclaration','ConstructorDeclaration',
                            'ConstantDeclaration','ArrayInitializer','VariableDeclaration','LocalVariableDeclaration',
                            'VariableDeclarator','FormalParameter','InferredFormalParameter','Statement','IfStatement',
                            'WhileStatement','DoStatement','ForStatement','AssertStatement','BreakStatement','ContinueStatement',
                            'ReturnStatement','ThrowStatement','SynchronizedStatement','TryStatement','SwitchStatement',
                            'BlockStatement','StatementExpression','TryResource','CatchClause','CatchClauseParameter',
                            'SwitchStatementCase','ForControl','EnhancedForControl','Expression','Assignment','TernaryExpression',
                            'BinaryOperation','Cast','MethodReference','LambdaExpression','Primary','Literal','This',
                            'MemberReference','Invocation','ExplicitConstructorInvocation','SuperConstructorInvocation',
                            'MethodInvocation','SuperMethodInvocation','SuperMemberReference','ArraySelector','ClassReference',
                            'VoidClassReference','Creator','ArrayCreator','ClassCreator','InnerClassCreator','EnumBody',
                            'EnumConstantDeclaration','AnnotationMethod', 'Modifier']
special_tokens_dict = {'additional_special_tokens': javalang_special_tokens}
num_added_toks = ast_tokenizer.add_special_tokens(special_tokens_dict)

In [None]:
#  generate tree for AST Node
def create_tree(root, node, node_list, sub_id_list, leave_list, tokenizer, parent=None):
    id = len(node_list)
    node_list.append(node)
    token, children = get_token(node), get_child(node)

    if children == []:
        # print('this is a leaf:', token, id)
        leave_list.append(id)

    # Use roberta.tokenizer to generate subtokens
    # If a token can be divided into multiple(>1) subtokens, the first subtoken will be set as the previous node, 
    # and the other subtokens will be set as its new children
    token = token.encode('utf-8','ignore').decode("utf-8")   
    sub_token_list = tokenizer.tokenize(token)
    
    if id == 0:
        root.token = sub_token_list[0] # the root node is one of the tokenizer's special tokens
        root.data = node
        # record the num of nodes for every children of root
        root_children_node_num = []
        for child in children:
            node_num = len(node_list)
            create_tree(root, child, node_list, sub_id_list, leave_list, tokenizer, parent=root)
            root_children_node_num.append(len(node_list) - node_num)        
        return root_children_node_num
    else:
        # print(sub_token_list)
        new_node = AnyNode(id=id, token=sub_token_list[0], data=node, parent=parent)
        if len(sub_token_list) > 1:
            sub_id_list.append(id)
            for sub_token in sub_token_list[1:]:
                id += 1
                AnyNode(id=id, token=sub_token, data=node, parent=new_node)
                node_list.append(sub_token)
                sub_id_list.append(id)
        
        for child in children:
            create_tree(root, child, node_list, sub_id_list, leave_list, tokenizer, parent=new_node)
    # print(token, id)

In [None]:
# traverse the AST tree to get all the nodes and edges
def get_node_and_edge(node, node_index_list, tokenizer, src, tgt, variable_token_list, variable_id_list):
    token = node.token
    node_index_list.append(tokenizer.convert_tokens_to_ids(token))
    # node_index_list.append([vocab_dict.word2id.get(token, UNK)])
    # find out all variables
    if token in ['VariableDeclarator', 'MemberReference']:
        if node.children: # some chidren are comprised by non-utf8 and will be removed
            variable_token_list.append(node.children[0].token)
            variable_id_list.append(node.children[0].id)   
    
    for child in node.children:
        src.append(node.id)
        tgt.append(child.id)
        src.append(child.id)
        tgt.append(node.id)
        get_node_and_edge(child, node_index_list, tokenizer, src, tgt, variable_token_list, variable_id_list)

In [None]:
# generate pytorch_geometric input format data from ast
def get_pyg_data_from_ast(ast, tokenizer):
    node_list = []
    sub_id_list = [] # record the ids of node that can be divide into multple subtokens
    leave_list = [] # record the ids of leave 
    new_tree = AnyNode(id=0, token=None, data=None)
    root_children_node_num = create_tree(new_tree, ast, node_list, sub_id_list, leave_list, tokenizer)
    # print('root_children_node_num', root_children_node_num)
    x = []
    edge_src = []
    edge_tgt = []
    # record variable tokens and ids to add data flow edge in AST graph
    variable_token_list = []
    variable_id_list = []
    get_node_and_edge(new_tree, x, tokenizer, edge_src, edge_tgt, variable_token_list, variable_id_list)

    ast_edge_num = len(edge_src)
    edge_attr = [[0] for _ in range(ast_edge_num)]
    # set subtoken edge type to 2
    for i in range(len(edge_attr)):
        if edge_src[i] in sub_id_list and edge_tgt[i] in sub_id_list:
            edge_attr[i] = [2]
    # add data flow edge
    variable_dict = {}
    for i in range(len(variable_token_list)):
        # print('variable_dict', variable_dict)
        if variable_token_list[i] not in variable_dict:
            variable_dict.setdefault(variable_token_list[i], variable_id_list[i])
        else:
            # print('edge', variable_dict.get(variable_token_list[i]), variable_id_list[i])
            edge_src.append(variable_dict.get(variable_token_list[i]))
            edge_tgt.append(variable_id_list[i])
            edge_src.append(variable_id_list[i])
            edge_tgt.append(variable_dict.get(variable_token_list[i]))
            variable_dict[variable_token_list[i]] = variable_id_list[i]
    dataflow_edge_num = len(edge_src) - ast_edge_num

    # add next-token edge
    nexttoken_edge_num = len(leave_list)-1
    for i in range(nexttoken_edge_num):
        edge_src.append(leave_list[i])
        edge_tgt.append(leave_list[i+1])
        edge_src.append(leave_list[i+1])
        edge_tgt.append(leave_list[i])

    edge_index = [edge_src, edge_tgt]

    # set data flow edge type to 1
    for _ in range(dataflow_edge_num):
        edge_attr.append([1])
    
    # set data flow edge type to 3
    for _ in range(nexttoken_edge_num * 2):
        edge_attr.append([3])
    
    return x, edge_index, edge_attr, root_children_node_num

In [None]:

def get_subgraph_node_num(root_children_node_num, divide_node_num):
    subgraph_node_num = []
    node_sum = 0
    real_graph_num = 0
    for num in root_children_node_num:
        node_sum += num
        if node_sum >= divide_node_num:
            subgraph_node_num.append(node_sum)
            node_sum = 0    
    
    subgraph_node_num.append(node_sum)
    real_graph_num = len(subgraph_node_num)

    if real_graph_num >= max_subgraph_num:
        return subgraph_node_num[: max_subgraph_num], max_subgraph_num

    # print(len(subgraph_node_num))
    # if the last subgraph node num < divide_node_num, then put the last subgraph to the second to last subgraph
    # if subgraph_node_num[-1] < divide_node_num:
    #     subgraph_node_num[-2] = subgraph_node_num[-2] + subgraph_node_num[-1]
    #     subgraph_node_num[-1] = 0
    #     real_graph_num -= 1

    # zero padding for tensor transforming
    for _ in range(real_graph_num, max_subgraph_num):
        subgraph_node_num.append(0)
    
    return subgraph_node_num, real_graph_num

In [None]:
def convert_examples_to_features(examples, ast_tokenizer, tokenizer, stage=None):
    features = []
    for example in tqdm(examples):
        # pyg
        ast = parse_program(example.source)
        x, edge_index, edge_attr, root_children_node_num = get_pyg_data_from_ast(ast, ast_tokenizer)
        subgraph_node_num, real_graph_num = get_subgraph_node_num(root_children_node_num, divide_node_num)

        # source
        source_tokens = tokenizer.tokenize(example.ast_des)[: max_source_length-2]
        source_tokens = [tokenizer.cls_token] + source_tokens + [tokenizer.sep_token]
        source_ids = tokenizer.convert_tokens_to_ids(source_tokens)
        source_mask = [1] * (len(source_ids))
        padding_length = max_source_length - len(source_ids)
        source_ids += [tokenizer.pad_token_id] * padding_length
        source_mask += [0] * padding_length

        # target
        if stage == 'test':
            target_tokens = tokenizer.tokenize('None')
        else:
            target_tokens = tokenizer.tokenize(example.target)[: max_target_length-2]
        target_tokens = [tokenizer.cls_token] + target_tokens + [tokenizer.sep_token]
        target_ids = tokenizer.convert_tokens_to_ids(target_tokens)
        target_mask = [1] * len(target_ids)
        padding_length = max_target_length - len(target_ids)
        target_ids += [tokenizer.pad_token_id] * padding_length
        target_mask += [0] * padding_length

        features.append(
            Data(
                x= torch.tensor(x, dtype=torch.long),
                edge_index=torch.tensor(edge_index, dtype=torch.long),
                edge_attr=torch.tensor(edge_attr, dtype=torch.long),
                source_ids=torch.tensor(source_ids, dtype=torch.long),
                source_mask=torch.tensor(source_mask, dtype=torch.long),
                target_ids=torch.tensor(target_ids, dtype=torch.long),
                target_mask=torch.tensor(target_mask, dtype=torch.long),
                subgraph_node_num=torch.tensor(subgraph_node_num, dtype=torch.long),
                real_graph_num=torch.tensor(real_graph_num, dtype=torch.long)
            )
        )
    return features

In [None]:
class Example(object):
    def __init__(self, idx, source, ast_des, target):
        self.idx = idx
        self.source = source
        self.ast_des = ast_des
        self.target = target

In [None]:
# read dataset
def read_examples(filename):
    examples = []
    with open(filename, encoding='utf-8') as f:
        for idx, line in enumerate(f):
            line = line.strip()
            js = json.loads(line)
            if 'idx' not in js:
                js['idx'] = idx
            
            code = js['code']
            nl = ' '.join(js['docstring_tokens']).replace('\n', '')
            nl = ' '.join(nl.strip().split())
            ast_des = js['ast_des']
            examples.append(
                Example(
                    idx = idx,
                    source = code,
                    ast_des = ast_des,
                    target = nl,
                )
            )
    return examples

In [None]:
train_examples = read_examples('/data/code/represent-code-in-human/data/TLC-SUM-enhanced/train.jsonl')
valid_examples = read_examples('/data/code/represent-code-in-human/data/TLC-SUM-enhanced/valid.jsonl')
test_examples = read_examples('/data/code/represent-code-in-human/data/TLC-SUM-enhanced/test.jsonl')

In [None]:
train_x = []
valid_x = []
test_x = []
for example in train_examples:
    ast = parse_program(example.source)
    x, edge_index, edge_attr, root_children_node_num = get_pyg_data_from_ast(ast, ast_tokenizer)
    train_x.append(len(x))

for example in valid_examples:
    ast = parse_program(example.source)
    x, edge_index, edge_attr, root_children_node_num = get_pyg_data_from_ast(ast, ast_tokenizer)
    valid_x.append(len(x))

for example in test_examples:
    ast = parse_program(example.source)
    x, edge_index, edge_attr, root_children_node_num = get_pyg_data_from_ast(ast, ast_tokenizer)
    test_x.append(len(x))

all_x = train_x + valid_x + test_x

print(sum(train_x)/len(train_x), sum(valid_x)/len(valid_x), sum(test_x)/len(test_x), sum(all_x)/len(all_x))

In [None]:
train_features = convert_examples_to_features(train_examples, ast_tokenizer, tokenizer, stage='train')
valid_features = convert_examples_to_features(valid_examples, ast_tokenizer, tokenizer, stage='valid')
test_features = convert_examples_to_features(test_examples, ast_tokenizer, tokenizer, stage='test')

100%|██████████| 69708/69708 [26:31<00:00, 43.81it/s]
100%|██████████| 8714/8714 [03:20<00:00, 41.54it/s]
100%|██████████| 8714/8714 [03:17<00:00, 44.09it/s]


In [None]:
# torch.save(train_features,'features/tlc/train_features.pt')
# torch.save(valid_features,'features/tlc/valid_features.pt')
# torch.save(test_features,'features/tlc/test_features.pt')

In [None]:
len(train_features)

In [None]:
len(valid_features)

In [None]:
len(test_features)