In [17]:
from transformers import RobertaTokenizer
import javalang
from javalang.ast import Node
from anytree import AnyNode
import torch

In [2]:
# Sample Java method (source text) used further down as the input to
# parse_program / get_pyg_data_from_ast. The tab indentation inside the
# literal is part of the string.
test_func = '''public int getLarger(int a, int b) {
		a = Math.abs(a);
		b = Math.abs(b);
		if(a > b) {
			return a;
		}else {
			return b;
		}
	}
'''

In [3]:
# Name of the pretrained checkpoint handed to `from_pretrained` below.
checkpoint = 'microsoft/codebert-base'
# Tokenizer used later to split AST node labels into sub-tokens and to map
# them to vocabulary ids.
tokenizer = RobertaTokenizer.from_pretrained(checkpoint)

In [4]:
# Labels that are registered as additional special tokens of the tokenizer:
# the javalang AST node class names, plus 'Modifier', which is the label
# get_token() assigns to a set-valued AST attribute.
javalang_special_tokens = ['CompilationUnit','Import','Documented','Declaration','TypeDeclaration','PackageDeclaration',
                            'ClassDeclaration','EnumDeclaration','InterfaceDeclaration','AnnotationDeclaration','Type',
                            'BasicType','ReferenceType','TypeArgument','TypeParameter','Annotation','ElementValuePair',
                            'ElementArrayValue','Member','MethodDeclaration','FieldDeclaration','ConstructorDeclaration',
                            'ConstantDeclaration','ArrayInitializer','VariableDeclaration','LocalVariableDeclaration',
                            'VariableDeclarator','FormalParameter','InferredFormalParameter','Statement','IfStatement',
                            'WhileStatement','DoStatement','ForStatement','AssertStatement','BreakStatement','ContinueStatement',
                            'ReturnStatement','ThrowStatement','SynchronizedStatement','TryStatement','SwitchStatement',
                            'BlockStatement','StatementExpression','TryResource','CatchClause','CatchClauseParameter',
                            'SwitchStatementCase','ForControl','EnhancedForControl','Expression','Assignment','TernaryExpression',
                            'BinaryOperation','Cast','MethodReference','LambdaExpression','Primary','Literal','This',
                            'MemberReference','Invocation','ExplicitConstructorInvocation','SuperConstructorInvocation',
                            'MethodInvocation','SuperMethodInvocation','SuperMemberReference','ArraySelector','ClassReference',
                            'VoidClassReference','Creator','ArrayCreator','ClassCreator','InnerClassCreator','EnumBody',
                            'EnumConstantDeclaration','AnnotationMethod', 'Modifier']

In [5]:
# Register the AST labels as additional special tokens of the tokenizer.
special_tokens_dict = {'additional_special_tokens': javalang_special_tokens}
# The return value of add_special_tokens is kept but not used in the cells shown here.
# NOTE(review): any model paired with this tokenizer presumably needs its
# embedding matrix resized for the added tokens - confirm downstream.
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

In [6]:
# use javalang to generate ASTs and depth-first traverse to generate ast nodes corpus
def get_token(node):
    """Return the textual label of an AST element.

    * a string is its own label,
    * a set is labelled ``'Modifier'``,
    * a javalang ``Node`` is labelled with its class name,
    * anything else is labelled ``'None'``.
    """
    if isinstance(node, str):
        return node
    if isinstance(node, set):
        return 'Modifier'
    if isinstance(node, Node):
        return node.__class__.__name__
    return 'None'


def get_child(root):
    """Return the flattened, non-empty children of an AST element.

    Args:
        root: a javalang ``Node``, a set, or any other value.

    Returns:
        A flat list of children. For a ``Node`` these come from
        ``root.children``; for a set they are its elements in sorted order;
        every other input has no children. Nested lists are flattened and
        falsy entries (``None``, ``''``, empty containers, ...) are dropped.
    """
    if isinstance(root, set):
        # Iteration order of a set of strings depends on the per-process hash
        # seed, which made child order (and therefore node ids) differ between
        # runs. Sorting makes the traversal reproducible; key=str keeps this
        # safe even if the elements are not mutually comparable.
        children = sorted(root, key=str)
    elif isinstance(root, Node):
        children = root.children
    else:
        children = []

    def expand(nested_list):
        # Depth-first flattening of arbitrarily nested lists.
        for item in nested_list:
            if isinstance(item, list):
                yield from expand(item)
            elif item:
                yield item

    return list(expand(children))


def get_sequence(node, sequence):
    """Append the labels of ``node`` and all its descendants to ``sequence``.

    The traversal is depth-first, pre-order (a node's label is emitted before
    those of its children, children are visited left to right). ``sequence``
    is modified in place; nothing is returned.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        sequence.append(get_token(current))
        # Push children reversed so the left-most child is popped first.
        pending.extend(reversed(get_child(current)))


def parse_program(func):
    """Parse Java member source text (e.g. a single method) into a javalang AST.

    ``func`` is tokenized with javalang's tokenizer and the token stream is
    parsed with ``Parser.parse_member_declaration``; the resulting tree is
    returned.
    """
    token_stream = javalang.tokenizer.tokenize(func)
    return javalang.parser.Parser(token_stream).parse_member_declaration()

In [7]:
#  generate tree for AST Node
def create_tree(root, node, node_list, sub_id_list, leave_list, tokenizer, parent=None):
    """Recursively mirror an AST into an anytree tree with sub-token nodes.

    Each visited AST element receives the next free id (``len(node_list)``).
    Its label is split by ``tokenizer.tokenize``: the first sub-token becomes
    the token of the element's own tree node and every further sub-token
    becomes an extra child node with its own consecutive id.

    Args:
        root: pre-created ``AnyNode``; the element that receives id 0 is
            stored on it (``token`` / ``data``) instead of on a new node.
        node: the AST element to convert.
        node_list: one entry per allocated id (the element, or the sub-token
            string for sub-token nodes); extended in place.
        sub_id_list: ids of nodes that were split plus ids of their sub-token
            children; extended in place.
        leave_list: ids of AST elements without children; extended in place.
        tokenizer: object providing ``tokenize(str) -> list``.
        parent: tree node to attach the new node to.
    """
    node_id = len(node_list)
    node_list.append(node)
    label = get_token(node)
    kids = get_child(node)

    if not kids:
        leave_list.append(node_id)

    pieces = tokenizer.tokenize(label)

    if node_id == 0:
        # Only the first sub-token is kept for the root; the root label is
        # expected to be one of the tokenizer's special tokens.
        root.token = pieces[0]
        root.data = node
        attach_to = root
    else:
        attach_to = AnyNode(id=node_id, token=pieces[0], data=node, parent=parent)
        remaining = pieces[1:]
        if remaining:
            sub_id_list.append(node_id)
            for offset, piece in enumerate(remaining, start=1):
                piece_id = node_id + offset
                AnyNode(id=piece_id, token=piece, data=node, parent=attach_to)
                # Keeps len(node_list) in step with the ids handed out so far.
                node_list.append(piece)
                sub_id_list.append(piece_id)

    for kid in kids:
        create_tree(root, kid, node_list, sub_id_list, leave_list, tokenizer, parent=attach_to)
    
    

In [8]:
# traverse the AST tree to get all the nodes and edges
def get_node_and_edge(node, node_index_list, tokenizer, src, tgt, variable_token_list, variable_id_list):
    """Walk the tree pre-order, collecting node ids, edges and variable uses.

    For every visited node the vocabulary id of its token is appended to
    ``node_index_list``. For every parent/child pair two directed edges
    (parent->child and child->parent) are appended to ``src`` / ``tgt``.
    For nodes whose token is ``'VariableDeclarator'`` or ``'MemberReference'``
    the token and id of the first child are recorded in
    ``variable_token_list`` / ``variable_id_list``. All lists are modified in
    place; nothing is returned.
    """
    label = node.token
    node_index_list.append(tokenizer.convert_tokens_to_ids(label))

    if label == 'VariableDeclarator' or label == 'MemberReference':
        # NOTE(review): the first child is taken as the variable name. For a
        # MemberReference that carries operators or a qualifier the first
        # child may be something other than the member - confirm.
        first_child = node.children[0]
        variable_token_list.append(first_child.token)
        variable_id_list.append(first_child.id)

    for kid in node.children:
        # One edge in each direction.
        src.extend((node.id, kid.id))
        tgt.extend((kid.id, node.id))
        get_node_and_edge(kid, node_index_list, tokenizer, src, tgt, variable_token_list, variable_id_list)

In [9]:
# generate pytorch_geometric input format data from ast
def get_pyg_data_from_ast(ast, tokenizer):
    """Convert an AST into node ids, an edge index and per-edge type labels.

    Args:
        ast: root AST element (as returned by ``parse_program``).
        tokenizer: tokenizer providing ``tokenize`` and
            ``convert_tokens_to_ids``.

    Returns:
        ``(x, edge_index, edge_attr)`` where

        * ``x`` holds one vocabulary id per tree node, in pre-order,
        * ``edge_index`` is ``[sources, targets]`` (two equally long lists),
        * ``edge_attr`` holds one single-element list per edge with its type:
          ``[0]`` AST edge, ``[2]`` AST edge whose endpoints both belong to a
          split token, ``[1]`` data-flow edge, ``[3]`` next-token edge.

        Edges are ordered AST edges first, then data-flow, then next-token
        edges; every relation is stored in both directions.
    """
    node_list = []
    sub_id_list = []  # ids of nodes that were split into several sub-tokens (and of those sub-tokens)
    leave_list = []   # ids of AST elements without children
    new_tree = AnyNode(id=0, token=None, data=None)
    create_tree(new_tree, ast, node_list, sub_id_list, leave_list, tokenizer)

    x = []
    edge_src = []
    edge_tgt = []
    # Variable tokens / ids, recorded to add data-flow edges to the AST graph.
    variable_token_list = []
    variable_id_list = []
    get_node_and_edge(new_tree, x, tokenizer, edge_src, edge_tgt, variable_token_list, variable_id_list)

    # AST edges: type 0, or type 2 when both endpoints are sub-token related.
    # A set makes each membership test O(1) instead of scanning the list for
    # every edge.
    sub_ids = set(sub_id_list)
    edge_attr = [
        [2] if source in sub_ids and target in sub_ids else [0]
        for source, target in zip(edge_src, edge_tgt)
    ]

    # Data-flow edges (type 1): link each occurrence of a variable token to
    # the previous occurrence of the same token.
    last_occurrence = {}
    for name, var_id in zip(variable_token_list, variable_id_list):
        if name in last_occurrence:
            previous_id = last_occurrence[name]
            edge_src.extend((previous_id, var_id))
            edge_tgt.extend((var_id, previous_id))
            edge_attr.extend(([1], [1]))
        last_occurrence[name] = var_id

    # Next-token edges (type 3): link consecutive leaves in traversal order.
    for left, right in zip(leave_list, leave_list[1:]):
        edge_src.extend((left, right))
        edge_tgt.extend((right, left))
        edge_attr.extend(([3], [3]))

    edge_index = [edge_src, edge_tgt]
    return x, edge_index, edge_attr

In [10]:
# Build node ids, edge index and edge types for the sample Java method.
x, edge_index, edge_attr = get_pyg_data_from_ast(ast=parse_program(func=test_func), tokenizer=tokenizer)

In [16]:
# Number of consecutive nodes of one graph that are grouped into one subgraph chunk.
divide_node_num = 2


def get_subgraph_info_from_batch(batch):
    """Locate the contiguous runs of equal graph ids in a batch vector.

    Args:
        batch: sequence (list or 1-D tensor) giving the graph id of every
            node; nodes of the same graph are expected to be contiguous.

    Returns:
        ``(graph_pos_in_batch, graph_length)``: the start index of every run
        and the number of nodes in it. An empty ``batch`` yields
        ``([0], [0])``.
    """
    graph_pos_in_batch = [0]  # start position of every graph
    graph_length = []         # node count of every graph
    total = len(batch)
    if total == 0:
        graph_length.append(0)
        return graph_pos_in_batch, graph_length

    # Start from the first id actually present instead of assuming it is 0;
    # otherwise a batch that starts with another id produces a phantom
    # zero-length graph at position 0.
    comp = batch[0]
    pos = 0
    for i in range(1, total):
        if batch[i] != comp:
            graph_pos_in_batch.append(i)
            graph_length.append(i - pos)
            comp = batch[i]
            pos = i
    graph_length.append(total - pos)
    return graph_pos_in_batch, graph_length


def partition_graph(x, edge_index, edge_attr, batch):
    """Split the edges of a batch of graphs into chunk-wise subgraphs.

    The nodes of every graph are cut into consecutive chunks of
    ``divide_node_num`` nodes; chunk ``k`` of all graphs together forms
    subgraph ``k``. An edge is assigned to the subgraph that contains its
    target node.

    Args:
        x: per-node data; only ``len(x)`` is used.
        edge_index: indexable as ``edge_index[0][i]`` / ``edge_index[1][i]``
            with elements supporting ``.item()`` (e.g. a ``2 x E`` tensor).
        edge_attr: indexable per edge, each entry supporting ``.item()``.
        batch: graph id of every node (see ``get_subgraph_info_from_batch``).

    Returns:
        ``(edge_index_list, edge_attr_list)``: one ``torch.long`` tensor of
        shape ``(2, E_k)`` and one of shape ``(E_k,)`` per subgraph.
    """
    graph_pos_in_batch, graph_length = get_subgraph_info_from_batch(batch)
    max_seq_len = max(graph_length)
    # Kept as floor + 1 so the number of returned subgraphs is unchanged for
    # callers; the last subgraph is empty when max_seq_len is an exact
    # multiple of divide_node_num.
    subgraph_num = int(max_seq_len / divide_node_num) + 1

    # Subgraph number of every node: the chunk index of the node inside its
    # own graph. The previous implementation dropped the chunk offset
    # (i * divide_node_num), so every chunk pointed at the first nodes of each
    # graph and later nodes were never assigned to their chunk.
    node_subgraph_index = [0] * len(x)
    for start, length in zip(graph_pos_in_batch, graph_length):
        for offset in range(length):
            node_subgraph_index[start + offset] = offset // divide_node_num

    # Only count an edge in the subgraph that holds its target node.
    sub_edge_src = [[] for _ in range(subgraph_num)]
    sub_edge_tgt = [[] for _ in range(subgraph_num)]
    sub_edge_attr = [[] for _ in range(subgraph_num)]
    for i in range(len(edge_index[1])):
        src = edge_index[0][i].item()
        tgt = edge_index[1][i].item()
        owner = node_subgraph_index[tgt]
        sub_edge_src[owner].append(src)
        sub_edge_tgt[owner].append(tgt)
        sub_edge_attr[owner].append(edge_attr[i].item())

    edge_index_list = [
        torch.tensor([sources, targets], dtype=torch.long)
        for sources, targets in zip(sub_edge_src, sub_edge_tgt)
    ]
    edge_attr_list = [torch.tensor(attrs, dtype=torch.long) for attrs in sub_edge_attr]
    return edge_index_list, edge_attr_list

[[0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [2],
 [2],
 [2],
 [2],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [0],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [1],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3],
 [3]]

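Use a regular expression to extract every AST node class name from the javalang source text below; these names are the special tokens added to the tokenizer.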

In [12]:
# Source text of the AST node class definitions, kept as a plain string so the
# class names can be pulled out with a regex; the string is never executed here.
# NOTE(review): presumably copied from javalang's node definitions - confirm
# it matches the installed javalang version.
ast_node = '''
from .ast import Node

# ------------------------------------------------------------------------------

class CompilationUnit(Node):
    attrs = ("package", "imports", "types")

class Import(Node):
    attrs = ("path", "static", "wildcard")

class Documented(Node):
    attrs = ("documentation",)

class Declaration(Node):
    attrs = ("modifiers", "annotations")

class TypeDeclaration(Declaration, Documented):
    attrs = ("name", "body")

    @property
    def fields(self):
        return [decl for decl in self.body if isinstance(decl, FieldDeclaration)]

    @property
    def methods(self):
        return [decl for decl in self.body if isinstance(decl, MethodDeclaration)]

    @property
    def constructors(self):
        return [decl for decl in self.body if isinstance(decl, ConstructorDeclaration)]

class PackageDeclaration(Declaration, Documented):
    attrs = ("name",)

class ClassDeclaration(TypeDeclaration):
    attrs = ("type_parameters", "extends", "implements")

class EnumDeclaration(TypeDeclaration):
    attrs = ("implements",)

    @property
    def fields(self):
        return [decl for decl in self.body.declarations if isinstance(decl, FieldDeclaration)]

    @property
    def methods(self):
        return [decl for decl in self.body.declarations if isinstance(decl, MethodDeclaration)]

class InterfaceDeclaration(TypeDeclaration):
    attrs = ("type_parameters", "extends",)

class AnnotationDeclaration(TypeDeclaration):
    attrs = ()

# ------------------------------------------------------------------------------

class Type(Node):
    attrs = ("name", "dimensions",)

class BasicType(Type):
    attrs = ()

class ReferenceType(Type):
    attrs = ("arguments", "sub_type")

class TypeArgument(Node):
    attrs = ("type", "pattern_type")

# ------------------------------------------------------------------------------

class TypeParameter(Node):
    attrs = ("name", "extends")

# ------------------------------------------------------------------------------

class Annotation(Node):
    attrs = ("name", "element")

class ElementValuePair(Node):
    attrs = ("name", "value")

class ElementArrayValue(Node):
    attrs = ("values",)

# ------------------------------------------------------------------------------

class Member(Documented):
    attrs = ()

class MethodDeclaration(Member, Declaration):
    attrs = ("type_parameters", "return_type", "name", "parameters", "throws", "body")

class FieldDeclaration(Member, Declaration):
    attrs = ("type", "declarators")

class ConstructorDeclaration(Declaration, Documented):
    attrs = ("type_parameters", "name", "parameters", "throws", "body")

# ------------------------------------------------------------------------------

class ConstantDeclaration(FieldDeclaration):
    attrs = ()

class ArrayInitializer(Node):
    attrs = ("initializers",)

class VariableDeclaration(Declaration):
    attrs = ("type", "declarators")

class LocalVariableDeclaration(VariableDeclaration):
    attrs = ()

class VariableDeclarator(Node):
    attrs = ("name", "dimensions", "initializer")

class FormalParameter(Declaration):
    attrs = ("type", "name", "varargs")

class InferredFormalParameter(Node):
    attrs = ('name',)

# ------------------------------------------------------------------------------

class Statement(Node):
    attrs = ("label",)

class IfStatement(Statement):
    attrs = ("condition", "then_statement", "else_statement")

class WhileStatement(Statement):
    attrs = ("condition", "body")

class DoStatement(Statement):
    attrs = ("condition", "body")

class ForStatement(Statement):
    attrs = ("control", "body")

class AssertStatement(Statement):
    attrs = ("condition", "value")

class BreakStatement(Statement):
    attrs = ("goto",)

class ContinueStatement(Statement):
    attrs = ("goto",)

class ReturnStatement(Statement):
    attrs = ("expression",)

class ThrowStatement(Statement):
    attrs = ("expression",)

class SynchronizedStatement(Statement):
    attrs = ("lock", "block")

class TryStatement(Statement):
    attrs = ("resources", "block", "catches", "finally_block")

class SwitchStatement(Statement):
    attrs = ("expression", "cases")

class BlockStatement(Statement):
    attrs = ("statements",)

class StatementExpression(Statement):
    attrs = ("expression",)

# ------------------------------------------------------------------------------

class TryResource(Declaration):
    attrs = ("type", "name", "value")

class CatchClause(Statement):
    attrs = ("parameter", "block")

class CatchClauseParameter(Declaration):
    attrs = ("types", "name")

# ------------------------------------------------------------------------------

class SwitchStatementCase(Node):
    attrs = ("case", "statements")

class ForControl(Node):
    attrs = ("init", "condition", "update")

class EnhancedForControl(Node):
    attrs = ("var", "iterable")

# ------------------------------------------------------------------------------

class Expression(Node):
    attrs = ()

class Assignment(Expression):
    attrs = ("expressionl", "value", "type")

class TernaryExpression(Expression):
    attrs = ("condition", "if_true", "if_false")

class BinaryOperation(Expression):
    attrs = ("operator", "operandl", "operandr")

class Cast(Expression):
    attrs = ("type", "expression")

class MethodReference(Expression):
    attrs = ("expression", "method", "type_arguments")

class LambdaExpression(Expression):
    attrs = ('parameters', 'body')

# ------------------------------------------------------------------------------

class Primary(Expression):
    attrs = ("prefix_operators", "postfix_operators", "qualifier", "selectors")

class Literal(Primary):
    attrs = ("value",)

class This(Primary):
    attrs = ()

class MemberReference(Primary):
    attrs = ("member",)

class Invocation(Primary):
    attrs = ("type_arguments", "arguments")

class ExplicitConstructorInvocation(Invocation):
    attrs = ()

class SuperConstructorInvocation(Invocation):
    attrs = ()

class MethodInvocation(Invocation):
    attrs = ("member",)

class SuperMethodInvocation(Invocation):
    attrs = ("member",)

class SuperMemberReference(Primary):
    attrs = ("member",)

class ArraySelector(Expression):
    attrs = ("index",)

class ClassReference(Primary):
    attrs = ("type",)

class VoidClassReference(ClassReference):
    attrs = ()

# ------------------------------------------------------------------------------

class Creator(Primary):
    attrs = ("type",)

class ArrayCreator(Creator):
    attrs = ("dimensions", "initializer")

class ClassCreator(Creator):
    attrs = ("constructor_type_arguments", "arguments", "body")

class InnerClassCreator(Creator):
    attrs = ("constructor_type_arguments", "arguments", "body")

# ------------------------------------------------------------------------------

class EnumBody(Node):
    attrs = ("constants", "declarations")

class EnumConstantDeclaration(Declaration, Documented):
    attrs = ("name", "arguments", "body")

class AnnotationMethod(Declaration):
    attrs = ("name", "return_type", "dimensions", "default")
'''

In [13]:
import re

# Collect the name of every class defined in the embedded source text.
# A raw string is used because '\(' in a normal string literal is an invalid
# escape sequence; the pattern is anchored to the start of a line so that only
# real `class Name(` headers are matched.
result = re.findall(r'^class (\w+)\(', ast_node, flags=re.MULTILINE)

In [14]:
# Join the class names with `','` so the displayed text can be pasted between
# the brackets of a Python list literal.
"','".join(result)

"CompilationUnit','Import','Documented','Declaration','TypeDeclaration','PackageDeclaration','ClassDeclaration','EnumDeclaration','InterfaceDeclaration','AnnotationDeclaration','Type','BasicType','ReferenceType','TypeArgument','TypeParameter','Annotation','ElementValuePair','ElementArrayValue','Member','MethodDeclaration','FieldDeclaration','ConstructorDeclaration','ConstantDeclaration','ArrayInitializer','VariableDeclaration','LocalVariableDeclaration','VariableDeclarator','FormalParameter','InferredFormalParameter','Statement','IfStatement','WhileStatement','DoStatement','ForStatement','AssertStatement','BreakStatement','ContinueStatement','ReturnStatement','ThrowStatement','SynchronizedStatement','TryStatement','SwitchStatement','BlockStatement','StatementExpression','TryResource','CatchClause','CatchClauseParameter','SwitchStatementCase','ForControl','EnhancedForControl','Expression','Assignment','TernaryExpression','BinaryOperation','Cast','MethodReference','LambdaExpression','Pri

In [15]:
# try:   
#     !jupyter nbconvert --to python subword_test.ipynb
#     # `--to python` converts the notebook to a .py script (`--to html` would produce .html instead)
#     # subword_test.ipynb is the file name of the current notebook
# except:
#     pass
